From 82199b8fe1ed77df2ff68e4edc73ee2e09baecc5 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 7 Oct 2023 11:44:18 +0800 Subject: [PATCH 001/216] Init commit for swbd (#1146) --- .../run-swbd-conformer-ctc-2023-08-26.sh | 44 + .github/workflows/run-swbd-conformer-ctc.yml | 84 ++ egs/swbd/ASR/.gitignore | 2 + egs/swbd/ASR/README.md | 25 + egs/swbd/ASR/RESULTS.md | 113 +++ egs/swbd/ASR/conformer_ctc/__init__.py | 0 egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 416 +++++++++ egs/swbd/ASR/conformer_ctc/conformer.py | 1 + egs/swbd/ASR/conformer_ctc/decode.py | 853 ++++++++++++++++++ egs/swbd/ASR/conformer_ctc/export.py | 163 ++++ egs/swbd/ASR/conformer_ctc/label_smoothing.py | 1 + egs/swbd/ASR/conformer_ctc/pretrained.py | 1 + egs/swbd/ASR/conformer_ctc/sclite_scoring.py | 148 +++ egs/swbd/ASR/conformer_ctc/subsampling.py | 1 + .../ASR/conformer_ctc/test_label_smoothing.py | 52 ++ .../ASR/conformer_ctc/test_subsampling.py | 48 + .../ASR/conformer_ctc/test_transformer.py | 1 + egs/swbd/ASR/conformer_ctc/train.py | 814 +++++++++++++++++ egs/swbd/ASR/conformer_ctc/transformer.py | 1 + egs/swbd/ASR/local/compile_hlg.py | 1 + egs/swbd/ASR/local/compile_lg.py | 1 + egs/swbd/ASR/local/compute_fbank_eval2000.py | 139 +++ egs/swbd/ASR/local/compute_fbank_swbd.py | 163 ++++ .../convert_transcript_words_to_tokens.py | 103 +++ egs/swbd/ASR/local/dict.patch | 380 ++++++++ .../ASR/local/display_manifest_statistics.py | 125 +++ egs/swbd/ASR/local/extend_segments.pl | 99 ++ egs/swbd/ASR/local/filter_cuts.py | 160 ++++ egs/swbd/ASR/local/filter_empty_text.py | 72 ++ egs/swbd/ASR/local/format_acronyms_dict.py | 118 +++ egs/swbd/ASR/local/generate_unique_lexicon.py | 98 ++ .../ASR/local/map_acronyms_transcripts.py | 60 ++ .../normalize_and_filter_supervisions.py | 283 ++++++ egs/swbd/ASR/local/normalize_eval2000.py | 234 +++++ egs/swbd/ASR/local/prepare_lang.py | 1 + egs/swbd/ASR/local/prepare_lang_bpe.py | 274 ++++++ .../ASR/local/prepare_lm_training_data.py | 1 + egs/swbd/ASR/local/rt03_data_prep.sh | 107 +++ egs/swbd/ASR/local/sort_lm_training_data.py | 141 +++ egs/swbd/ASR/local/swbd1_data_prep.sh | 128 +++ egs/swbd/ASR/local/swbd1_map_words.pl | 52 ++ egs/swbd/ASR/local/swbd1_prepare_dict.sh | 101 +++ egs/swbd/ASR/local/train_bpe_model.py | 102 +++ egs/swbd/ASR/local/validate_bpe_lexicon.py | 1 + egs/swbd/ASR/prepare.sh | 463 ++++++++++ egs/swbd/ASR/shared | 1 + egs/swbd/ASR/utils/filter_scp.pl | 87 ++ egs/swbd/ASR/utils/fix_data_dir.sh | 197 ++++ egs/swbd/ASR/utils/parse_options.sh | 97 ++ egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl | 27 + egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl | 38 + 51 files changed, 6622 insertions(+) create mode 100755 .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh create mode 100644 .github/workflows/run-swbd-conformer-ctc.yml create mode 100644 egs/swbd/ASR/.gitignore create mode 100644 egs/swbd/ASR/README.md create mode 100644 egs/swbd/ASR/RESULTS.md create mode 100644 egs/swbd/ASR/conformer_ctc/__init__.py create mode 100644 egs/swbd/ASR/conformer_ctc/asr_datamodule.py create mode 120000 egs/swbd/ASR/conformer_ctc/conformer.py create mode 100755 egs/swbd/ASR/conformer_ctc/decode.py create mode 100755 egs/swbd/ASR/conformer_ctc/export.py create mode 120000 egs/swbd/ASR/conformer_ctc/label_smoothing.py create mode 120000 egs/swbd/ASR/conformer_ctc/pretrained.py create mode 100755 egs/swbd/ASR/conformer_ctc/sclite_scoring.py create mode 120000 egs/swbd/ASR/conformer_ctc/subsampling.py create mode 100755 egs/swbd/ASR/conformer_ctc/test_label_smoothing.py create mode 100755 
egs/swbd/ASR/conformer_ctc/test_subsampling.py create mode 120000 egs/swbd/ASR/conformer_ctc/test_transformer.py create mode 100755 egs/swbd/ASR/conformer_ctc/train.py create mode 120000 egs/swbd/ASR/conformer_ctc/transformer.py create mode 120000 egs/swbd/ASR/local/compile_hlg.py create mode 120000 egs/swbd/ASR/local/compile_lg.py create mode 100755 egs/swbd/ASR/local/compute_fbank_eval2000.py create mode 100755 egs/swbd/ASR/local/compute_fbank_swbd.py create mode 100755 egs/swbd/ASR/local/convert_transcript_words_to_tokens.py create mode 100644 egs/swbd/ASR/local/dict.patch create mode 100755 egs/swbd/ASR/local/display_manifest_statistics.py create mode 100755 egs/swbd/ASR/local/extend_segments.pl create mode 100755 egs/swbd/ASR/local/filter_cuts.py create mode 100755 egs/swbd/ASR/local/filter_empty_text.py create mode 100755 egs/swbd/ASR/local/format_acronyms_dict.py create mode 100755 egs/swbd/ASR/local/generate_unique_lexicon.py create mode 100755 egs/swbd/ASR/local/map_acronyms_transcripts.py create mode 100755 egs/swbd/ASR/local/normalize_and_filter_supervisions.py create mode 100755 egs/swbd/ASR/local/normalize_eval2000.py create mode 120000 egs/swbd/ASR/local/prepare_lang.py create mode 100755 egs/swbd/ASR/local/prepare_lang_bpe.py create mode 120000 egs/swbd/ASR/local/prepare_lm_training_data.py create mode 100755 egs/swbd/ASR/local/rt03_data_prep.sh create mode 100755 egs/swbd/ASR/local/sort_lm_training_data.py create mode 100755 egs/swbd/ASR/local/swbd1_data_prep.sh create mode 100755 egs/swbd/ASR/local/swbd1_map_words.pl create mode 100755 egs/swbd/ASR/local/swbd1_prepare_dict.sh create mode 100755 egs/swbd/ASR/local/train_bpe_model.py create mode 120000 egs/swbd/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/swbd/ASR/prepare.sh create mode 120000 egs/swbd/ASR/shared create mode 100755 egs/swbd/ASR/utils/filter_scp.pl create mode 100755 egs/swbd/ASR/utils/fix_data_dir.sh create mode 100755 egs/swbd/ASR/utils/parse_options.sh create mode 100755 egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl create mode 100755 egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh new file mode 100755 index 000000000..d8cc020e1 --- /dev/null +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/swbd/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-98.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + +for method in ctc-decoding 1best; do + log "$method" + + ./conformer_ctc/pretrained.py \ + --method $method \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --G $repo/data/lm/G_4_gram.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/workflows/run-swbd-conformer-ctc.yml b/.github/workflows/run-swbd-conformer-ctc.yml new file mode 100644 index 000000000..842691d38 --- 
/dev/null +++ b/.github/workflows/run-swbd-conformer-ctc.yml @@ -0,0 +1,84 @@ +# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-swbd-conformer_ctc + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +concurrency: + group: run-swbd-conformer_ctc-${{ github.ref }} + cancel-in-progress: true + +jobs: + run-swbd-conformer_ctc: + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'swbd' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh diff --git a/egs/swbd/ASR/.gitignore b/egs/swbd/ASR/.gitignore new file mode 100644 index 000000000..11d674922 --- /dev/null +++ b/egs/swbd/ASR/.gitignore @@ -0,0 +1,2 @@ +switchboard_word_alignments.tar.gz +./swb_ms98_transcriptions/ diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md new file mode 100644 index 000000000..13b27815a --- /dev/null +++ b/egs/swbd/ASR/README.md @@ -0,0 +1,25 @@ +# Switchboard + +The Switchboard-1 Telephone Speech Corpus (LDC97S62) consists of approximately 260 hours of speech and was originally collected by Texas Instruments in 1990-1, under DARPA sponsorship. The first release of the corpus was published by NIST and distributed by the LDC in 1992-3. Since that release, a number of corrections have been made to the data files as presented on the original CD-ROM set and all copies of the first pressing have been distributed. + +Switchboard is a collection of about 2,400 two-sided telephone conversations among 543 speakers (302 male, 241 female) from all areas of the United States. 
A computer-driven robot operator system handled the calls, giving the caller appropriate recorded prompts, selecting and dialing another person (the callee) to take part in a conversation, introducing a topic for discussion and recording the speech from the two subjects into separate channels until the conversation was finished. About 70 topics were provided, of which about 50 were used frequently. Selection of topics and callees was constrained so that: (1) no two speakers would converse together more than once and (2) no one spoke more than once on a given topic. + +(The above introduction is from the [LDC Switchboard-1 Release 2 webpage](https://catalog.ldc.upenn.edu/LDC97S62).) + + +## Performance Record +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +See [RESULTS](/egs/swbd/ASR/RESULTS.md) for details. + +## Credit + +The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ctc` recipe in icefall. + +Many of the data processing scripts come from first-generation Kaldi and the ESPnet project, adapted here to work with Lhotse and icefall. + +Some of the text normalization scripts are from stale pull requests by [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17). + +The `sclite_scoring.py` script is from the GigaSpeech recipe; it performs post-processing and GLM-like scoring, which is admittedly not an elegant solution. diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md new file mode 100644 index 000000000..f3a22c444 --- /dev/null +++ b/egs/swbd/ASR/RESULTS.md @@ -0,0 +1,113 @@ +## Results +### Switchboard BPE training results (Conformer-CTC) + +#### 2023-09-04 + +The best WER for Switchboard, as of 2023-09-04, is given below. + +Results using the attention decoder are given below: + +| | eval2000-swbd | eval2000-callhome | eval2000-avg | +|--------------------------------|-----------------|---------------------|--------------| +| `conformer_ctc` | 9.48 | 17.73 | 13.67 | + +Decoding results and models can be found here: +https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 + +#### 2023-06-27 + +The best WER for Switchboard, as of 2023-06-27, is given below. + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 30.80 | 32.29 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.1 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.9 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ + --num-epochs 100 +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 99 \ + --avg 10 \ + --max-duration 50 +``` + +#### 2023-06-26 + +The best WER for Switchboard, as of 2023-06-26, is given below. + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +Scale values used in n-gram LM rescoring and attention rescoring for the 
best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.3 | 2.5 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.7 | 1.3 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 55 \ + --avg 1 \ + --max-duration 50 +``` + +For your reference, the nbest oracle WERs are: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 25.64 | 26.84 | diff --git a/egs/swbd/ASR/conformer_ctc/__init__.py b/egs/swbd/ASR/conformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 000000000..59d73c660 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,416 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class SwitchBoardAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train dataloader, + but there can be multiple test dataloaders (e.g. SwitchBoard rt03 + and eval2000). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
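+
+    A minimal usage sketch (illustrative only; it assumes the manifests
+    produced by ./prepare.sh are available under --manifest-dir):
+
+        import argparse
+        parser = argparse.ArgumentParser()
+        SwitchBoardAsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+        switchboard = SwitchBoardAsrDataModule(args)
+        train_dl = switchboard.train_dataloaders(switchboard.train_all_cuts())
+        valid_dl = switchboard.valid_dataloaders(switchboard.dev_cuts())
+        test_dl = switchboard.test_dataloaders(switchboard.test_eval2000_cuts())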
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=50, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + buffer_size=50000, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_all_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(last=166844) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(first=300) + + @lru_cache() + def test_eval2000_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get eval2000 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "eval2000" / "eval2000_cuts_all.jsonl.gz" + ) + + @lru_cache() + def test_rt03_cuts(self) -> CutSet: + logging.info("SwitchBoard: 
About to get rt03 cuts") + return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_rt03.jsonl.gz") diff --git a/egs/swbd/ASR/conformer_ctc/conformer.py b/egs/swbd/ASR/conformer_ctc/conformer.py new file mode 120000 index 000000000..d1f4209d7 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py new file mode 100755 index 000000000..2bbade374 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -0,0 +1,853 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer + +from sclite_scoring import asr_text_post_processing + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_rnn_lm, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=98, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. 
Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. 
+ """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. 
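+
+    For illustration only (the words are made up), a returned dict for a
+    batch of two utterances might look like:
+
+        {"lm_scale_0.7": [["hello", "world"], ["how", "are", "you"]]}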
+ """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
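+        # nbest_oracle samples params.num_paths paths from the lattice and,
+        # for each utterance, keeps the path closest to the reference
+        # transcript, so the resulting WER is a lower bound for any n-best
+        # rescoring method.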
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. 
+ HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + if test_set_name == "test-eval2000": + subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"} + elif test_set_name == "test-rt03": + subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"} + else: + raise NotImplementedError(f"No implementation for testset {test_set_name}") + for subset, prefix in subsets.items(): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt" + results = post_processing(results) + results = ( + sorted(list(filter(lambda x: x[0].startswith(prefix), results))) + if subset != "avg" + else sorted(results) + ) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"{test_set_name}-{subset}-{key}", + results, + enable_log=enable_log, + sclite_mode=True, + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{subset}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}-{}, WER of different settings are:\n".format( + test_set_name, subset + ) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. 
+ G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + + # we need cut ids to display recognition results. + args.return_cuts = True + switchboard = SwitchBoardAsrDataModule(args) + + test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions( + keep_all_channels=True + ) + # test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( + # keep_all_channels=True + # ) + + test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) + # test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) + + # test_sets = ["test-eval2000", "test-rt03"] + # test_dl = [test_eval2000_dl, test_rt03_dl] + test_sets = ["test-eval2000"] + test_dl = [test_eval2000_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py new file mode 100755 index 000000000..1bb6277ad --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=98, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded 
+ # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/conformer_ctc/label_smoothing.py b/egs/swbd/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py new file mode 120000 index 000000000..526bc9678 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py new file mode 100755 index 000000000..0383c4d71 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright 2021 Jiayu Du +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", + "MHM", + "HUM", + "AW", + "OH", + "HMM", + "UMM", +] +unk_tags = ["", ""] +switchboard_garbage_utterance_tags = [ + "[LAUGHTER]", + "[NOISE]", + "[VOCALIZED-NOISE]", + "[SILENCE]", +] +non_scoring_words = ( + conversational_filler + unk_tags + switchboard_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. 
remove non-scoring words from evaluation + remaining_words = [] + text_split = text.split() + word_to_skip = 0 + for idx, word in enumerate(text_split): + if word_to_skip > 0: + word_to_skip -= 1 + continue + if word in non_scoring_words: + continue + elif word == "CANCELLED": + remaining_words.append("CANCELED") + continue + elif word == "AIRFLOW": + remaining_words.append("AIR") + remaining_words.append("FLOW") + continue + elif word == "PHD": + remaining_words.append("P") + remaining_words.append("H") + remaining_words.append("D") + continue + elif word == "UCLA": + remaining_words.append("U") + remaining_words.append("C") + remaining_words.append("L") + remaining_words.append("A") + continue + elif word == "ONTO": + remaining_words.append("ON") + remaining_words.append("TO") + continue + elif word == "DAY": + try: + if text_split[idx + 1] == "CARE": + remaining_words.append("DAYCARE") + word_to_skip = 1 + except: + remaining_words.append(word) + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" + ) + parser.add_argument( + "ref", + type=str, + help="sclite's standard transcription(trn) reference file", + ) + parser.add_argument( + "hyp", + type=str, + help="sclite's standard transcription(trn) hypothesis file", + ) + parser.add_argument( + "work_dir", + type=str, + help="working dir", + ) + args = parser.parse_args() + + if not os.path.isdir(args.work_dir): + os.mkdir(args.work_dir) + + REF = os.path.join(args.work_dir, "REF") + HYP = os.path.join(args.work_dir, "HYP") + RESULT = os.path.join(args.work_dir, "RESULT") + + for io in [(args.ref, REF), (args.hyp, HYP)]: + with open(io[0], "r", encoding="utf8") as fi: + with open(io[1], "w+", encoding="utf8") as fo: + for line in fi: + line = line.strip() + if line: + cols = line.split() + text = asr_text_post_processing(" ".join(cols[0:-1])) + uttid_field = cols[-1] + print(f"{text} {uttid_field}", file=fo) + + # GigaSpeech's uttid comforms to swb + os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/conformer_ctc/subsampling.py b/egs/swbd/ASR/conformer_ctc/subsampling.py new file mode 120000 index 000000000..16354dc73 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py new file mode 100755 index 000000000..5d4438fd1 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
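+
+# This test compares the recipe's custom LabelSmoothingLoss against
+# torch.nn.CrossEntropyLoss(label_smoothing=...), which is only available in
+# torch >= 1.10.  With smoothing factor eps and K classes, the smoothed
+# target distribution is (roughly)
+#
+#     q(k) = (1 - eps) * one_hot(k) + eps / K
+#
+# so the two implementations are expected to agree for the "none", "sum" and
+# "mean" reductions up to floating point tolerance (torch.allclose).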
+ +from distutils.version import LooseVersion + +import torch +from label_smoothing import LabelSmoothingLoss + +torch_ver = LooseVersion(torch.__version__) + + +def test_with_torch_label_smoothing_loss(): + if torch_ver < LooseVersion("1.10.0"): + print(f"Current torch version: {torch_ver}") + print("Please use torch >= 1.10 to run this test - skipping") + return + torch.manual_seed(20211105) + x = torch.rand(20, 30, 5000) + tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2]) + for reduction in ["none", "sum", "mean"]: + custom_loss_func = LabelSmoothingLoss( + ignore_index=-1, label_smoothing=0.1, reduction=reduction + ) + custom_loss = custom_loss_func(x, tgt) + + torch_loss_func = torch.nn.CrossEntropyLoss( + ignore_index=-1, reduction=reduction, label_smoothing=0.1 + ) + torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1)) + assert torch.allclose(custom_loss, torch_loss) + + +def main(): + test_with_torch_label_smoothing_loss() + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/test_subsampling.py b/egs/swbd/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 000000000..81fa234dd --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from subsampling import Conv2dSubsampling, VggSubsampling + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/swbd/ASR/conformer_ctc/test_transformer.py b/egs/swbd/ASR/conformer_ctc/test_transformer.py new file mode 120000 index 000000000..8b0990ec6 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py new file mode 100755 index 000000000..7f1eebbcf --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/train.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./conformer_ctc/train.py \ + --exp-dir ./conformer_ctc/exp \ + --world-size 4 \ + --max-duration 200 \ + --num-epochs 20 +""" + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=98, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. 
+ The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. 
+ + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
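+    Returns:
+      Return a tuple of two elements: the loss tensor (the CTC loss,
+      optionally interpolated with the attention-decoder loss according to
+      params.att_rate) and a MetricsTracker with statistics for this batch,
+      e.g. the number of frames and the individual loss values.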
+ """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if 
world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + if "lang_bpe" in str(params.lang_dir): + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in str(params.lang_dir): + assert params.att_rate == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. Set --att-rate=0 " + "for pure CTC training when using a phone-based lang dir." + ) + assert params.num_decoder_layers == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. " + "Set --num-decoder-layers=0 for pure CTC training when using " + "a phone-based lang dir." + ) + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + switchboard = SwitchBoardAsrDataModule(args) + + train_cuts = switchboard.train_all_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = switchboard.train_dataloaders(train_cuts) + + valid_cuts = switchboard.dev_cuts() + valid_dl = switchboard.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/transformer.py b/egs/swbd/ASR/conformer_ctc/transformer.py new file mode 120000 index 000000000..1c3f43fcf --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/swbd/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/swbd/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py new file mode 100755 index 000000000..d446e8ff3 --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_eval2000.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. 
If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + manifests = { + "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz", + } + assert manifests is not None + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. + partition = "all" + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = CutSet.from_file(manifests[prefix]).resample(16000) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_switchboard( + dir_name="eval2000", + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py new file mode 100755 index 000000000..dd82220c0 --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. 
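+
+A possible invocation (for illustration only; the BPE model path follows the
+data/lang_bpe_500 convention used elsewhere in this recipe, and the valid
+range of --split-index depends on how the training manifest was split):
+
+  ./local/compute_fbank_swbd.py \
+    --split-index 1 \
+    --bpe-model data/lang_bpe_500/bpe.model \
+    --perturb-speed true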
+""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + parser.add_argument( + "--split-index", + type=int, + required=True, + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + split_index: int, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}_split16") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + split_dir = Path("data/manifests/swbd_split16/") + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. 
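+        # For this split we: load the trimmed training manifest, resample the
+        # audio to 16 kHz, drop cuts no longer than 2 seconds, optionally
+        # filter with the BPE model and add 0.9x/1.1x speed-perturbed copies,
+        # compute 80-dim fbank features, and trim the cuts to their
+        # supervisions before writing the cut manifest.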
+ partition = "all" + cuts_filename = ( + f"{prefix}_cuts_{partition}.{str(split_index).zfill(2)}.{suffix}" + ) + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = ( + CutSet.from_file( + split_dir + / f"swbd_train_all_trimmed.{str(split_index).zfill(2)}.jsonl.gz" + ) + .resample(16000) + .to_eager() + .filter(lambda c: c.duration > 2.0) + ) + + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}_{str(split_index).zfill(2)}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, + min_duration=None, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + compute_fbank_switchboard( + dir_name="swbd", + split_index=args.split_index, + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py new file mode 100755 index 000000000..a8d5117c9 --- /dev/null +++ b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument("--oov", type=str, default="", help="The OOV word.") + + return parser.parse_args() + + +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). + line: + A line of transcript consisting of space(s) separated words. 
+ oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = line.strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/dict.patch b/egs/swbd/ASR/local/dict.patch new file mode 100644 index 000000000..12c63d612 --- /dev/null +++ b/egs/swbd/ASR/local/dict.patch @@ -0,0 +1,380 @@ +1d0 +< file: $SWB/data/dictionary/sw-ms98-dict.text +8645a8646 +> uh-hum ah m hh ah m +9006c9007 +< April ey p r ih l +--- +> April ey p r ax l +9144d9144 +< B ay zh aa n iy z +9261c9261 +< Battle b ae t el +--- +> Battle b ae t ax l +10014a10015 +> Chevy sh eh v iy +10211a10213 +> Colorado k ao l ax r aa d ow +10212a10215 +> Colorado' k ao l ax r aa d ow z +10370c10373 +< Creek k r ih k +--- +> Creek k r iy k +10889a10893 +> Eleven ax l eh v ih n +10951c10955 +< Erie ih r iy +--- +> Erie iy r iy +11183c11187 +< Forever f ax r eh v er +--- +> Forever f er eh v er +11231a11236 +> Friday f r ay d iy +11744a11750 +> History hh ih s t r iy +12004a12011,12012 +> Israel ih z r ih l +> Israel's ih z r ih l z +12573a12582 +> Lincoln l ih ng k ih n +12574a12584 +> Lincolns l ih ng k ih n z +13268c13278 +< NAACP eh ey ey s iy p iy +--- +> NAACP eh n ey ey s iy p iy +13286c13296 +< NIT eh ay t iy +--- +> NIT eh n ay t iy +13292c13302 +< NTSC eh t iy eh s s iy +--- +> NTSC eh n t iy eh s s iy +14058a14069 +> Quarter k ow r t er +14059a14071 +> Quarterback k ow r t er b ae k +14060a14073 +> Quarters k ow r t er z +14569a14583 +> Science s ay n s +15087a15102 +> Sunday s ah n d iy +15088a15104 +> Sunday's s ah n d iy z +15089a15106 +> Sundays s ah n d iy z +15290,15291c15307,15308 +< Texan t eh k sh ih n +< Texan's t eh k sh ih n s +--- +> Texan t eh k s ih n +> Texan's t eh k s ih n s +15335a15353 +> Thousands th aw z ih n z +15739c15757 +< Waco w ae k ow +--- +> Waco w ey k ow +15841a15860 +> Weekends w iy k eh n z +16782a16802 +> acceptable eh k s eh p ax b ax l +16833a16854 +> accounting ax k aw n ih ng +16948a16970 +> address ax d r eh s +17281a17304 +> already aa r d iy +17315a17339 +> am m +17709a17734 +> asked ae s t +17847a17873 +> attorney ih t er n iy +17919a17946 +> autopilot ao t ow p ay l ih t +17960a17988 +> awfully ao f l iy +18221a18250 +> basketball b ae s k ax b ao l +18222a18252 +> basketball's b ae s k ax b ao l z +18302a18333 +> become b ah k ah m +18303a18335 +> becomes b iy k ah m z +18344a18377 +> began b ax g en n +18817c18850 +< bottle b aa t el +--- +> bottle b aa t ax l +19332,19333c19365,19367 +< camera's k ae m ax r ax z +< cameras k ae m ax r ax z +--- +> camera k ae m r ax +> camera's k ae m r ax z +> cameras k ae m r ax z +19411a19446 +> capital k ae p ax l +19505a19541 +> carrying k ae r ih ng +20316a20353,20354 +> combination k aa m ih n ey sh ih n +> combinations k aa m ih n ey sh ih n z +20831a20870 +> contracts k aa n t r ae k s +21010a21050 +> costs k ao s +21062a21103 +> 
county k aw n iy +21371a21413 +> cultural k ao l ch ax r ax l +21372a21415 +> culturally k ao l ch ax r ax l iy +21373a21417 +> culture k ao l ch er +21375a21420 +> cultures k ao l ch er z +21543a21589 +> data d ey t ax +22097a22144 +> differently d ih f ax r ih n t l iy +22972a23020 +> effects ax f eh k t s +23016a23065 +> election ax l eh k sh ih n +23018a23068 +> elections ax l eh k sh ih n z +23052a23103 +> eleven ax l eh v ih n +23242a23294 +> enjoyable ae n jh oy ax b ax l +23248a23301 +> enjoys ae n jh oy z +23293a23347 +> entire ih n t ay r +23295a23350,23351 +> entirely ih n t ay r l iy +> entirety ih n t ay r t iy +23745a23802 +> extra eh k s t er +23818a23876 +> facts f ae k s +24508c24566 +< forever f ax r eh v er +--- +> forever f er eh v er +24514c24572 +< forget f ow r g eh t +--- +> forget f er r g eh t +24521a24580 +> forgot f er r g aa t +24522a24582 +> forgotten f er r g aa t ax n +24563a24624 +> forward f ow er d +24680a24742 +> frightening f r ay t n ih ng +24742a24805 +> full-time f ax l t ay m +24862a24926 +> garage g r aa jh +25218a25283 +> grandmother g r ae m ah dh er +25790a25856 +> heavily hh eh v ax l iy +25949a26016 +> history hh ih s t r iy +26038a26106 +> honestly aa n ax s t l iy +26039a26108 +> honesty aa n ax s t iy +26099a26169 +> horror hh ow r +26155a26226 +> houses hh aw z ih z +26184c26255 +< huh-uh hh ah hh ah +--- +> huh-uh ah hh ah +26189c26260 +< hum-um hh m hh m +--- +> hum-um ah m hh ah m +26236a26308 +> hunting hh ah n ih ng +26307a26380,26381 +> ideal ay d iy l +> idealist ay d iy l ih s t +26369a26444 +> imagine m ae jh ih n +26628a26704 +> individuals ih n d ih v ih jh ax l z +26968a27045 +> interest ih n t r ih s t +27184a27262 +> it'd ih d +27702a27781 +> lead l iy d +28378a28458 +> mandatory m ae n d ih t ow r iy +28885a28966 +> minute m ih n ih t +29167a29249 +> mountains m aw t n z +29317a29400 +> mysteries m ih s t r iy z +29318a29402 +> mystery m ih s t r iy +29470a29555 +> nervous n er v ih s +29578,29580c29663,29665 +< nobody n ow b aa d iy +< nobody'll n ow b aa d iy l +< nobody's n ow b aa d iy z +--- +> nobody n ow b ah d iy +> nobody'll n ow b ah d iy l +> nobody's n ow b ah d iy z +29712a29798 +> nuclear n uw k l iy r +29938a30025 +> onto aa n t ax +30051a30139 +> originally ax r ih jh ax l iy +30507a30596 +> particularly p er t ih k y ax l iy +30755a30845 +> perfectly p er f ih k l iy +30820a30911 +> personally p er s n ax l iy +30915a31007 +> physically f ih z ih k l iy +30986a31079 +> pilot p ay l ih t +30987a31081 +> pilot's p ay l ih t s +31227a31322 +> police p l iy s +31513a31609 +> prefer p er f er +31553a31650 +> prepare p r ax p ey r +31578a31676 +> prescription p er s k r ih p sh ih n +31579a31678 +> prescriptions p er s k r ih p sh ih n z +31770a31870 +> products p r aa d ax k s +31821a31922 +> projects p r aa jh eh k s +31908a32010 +> protect p er t eh k t +31909a32012 +> protected p er t eh k t ih d +31911a32015 +> protection p er t eh k sh ih n +31914a32019 +> protection p er t eh k t ih v +32149a32255 +> quarter k ow r t er +32414a32521 +> read r iy d +32785a32893 +> rehabilitation r iy ax b ih l ih t ey sh ih n +33150a33259 +> resource r ih s ow r s +33151a33261 +> resources r iy s ow r s ih z +33539c33649 +< roots r uh t s +--- +> roots r uw t s +33929a34040 +> science s ay n s +34315a34427 +> seventy s eh v ih n iy +34319,34320c34431,34432 +< severe s ax v iy r +< severely s ax v iy r l iy +--- +> severe s ih v iy r +> severely s ih v iy r l iy +35060a35173 +> software s ao f w ey r +35083a35197 +> solid s 
ao l ih d +35084a35199 +> solidly s ao l ih d l iy +35750a35866 +> stood s t ih d +35854a35971 +> strictly s t r ih k l iy +35889c36006 +< stronger s t r ao ng er +--- +> stronger s t r ao ng g er +36192a36310,36311 +> supposed s p ow z +> supposed s p ow s +36510a36630 +> tastes t ey s +36856a36977 +> thoroughly th er r l iy +36866a36988 +> thousands th aw z ih n z +37081c37203 +< toots t uh t s +--- +> toots t uw t s +37157a37280 +> toward t w ow r d +37158a37282 +> towards t w ow r d z +37564a37689 +> twenties t w eh n iy z +37565a37691 +> twentieth t w eh n iy ih th +37637a37764 +> unacceptable ah n ae k s eh p ax b ax l +37728a37856 +> understand ah n d er s t ae n +37860a37989 +> unless ih n l eh s +38040a38170 +> use y uw z +38049a38180 +> uses y uw z ih z +38125a38257 +> various v ah r iy ih s +38202a38335 +> versus v er s ih z +38381c38514 +< wacko w ae k ow +--- +> wacko w ey k ow +38455c38588 +< wanna w aa n ax +--- +> wanna w ah n ax +38675c38808 +< whatnot w ah t n aa t +--- +> whatnot w aa t n aa t +38676a38810 +> whatsoever w aa t s ow eh v er +38890c39024 +< wok w aa k +--- +> wok w ao k +38910a39045 +> wondering w ah n d r ih ng diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..9aa204863 --- /dev/null +++ b/egs/swbd/ASR/local/display_manifest_statistics.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. 
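+
+In this recipe the same kind of filtering is done in
+conformer_ctc/train.py (see remove_short_and_long_utt) and in
+local/filter_cuts.py.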
+""" + + +from lhotse import load_manifest_lazy + + +def main(): + # path = "./data/fbank/swbd_cuts_rt03.jsonl.gz" + path = "./data/fbank/eval2000/eval2000_cuts_all.jsonl.gz" + # path = "./data/fbank/swbd_cuts_all.jsonl.gz" + + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Training Cut statistics: +╒═══════════════════════════╤═══════════╕ +│ Cuts count: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Total duration (hh:mm:ss) │ 281:01:26 │ +├───────────────────────────┼───────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼───────────┤ +│ std │ 3.3 │ +├───────────────────────────┼───────────┤ +│ min │ 2.0 │ +├───────────────────────────┼───────────┤ +│ 25% │ 3.2 │ +├───────────────────────────┼───────────┤ +│ 50% │ 5.2 │ +├───────────────────────────┼───────────┤ +│ 75% │ 8.3 │ +├───────────────────────────┼───────────┤ +│ 99% │ 14.4 │ +├───────────────────────────┼───────────┤ +│ 99.5% │ 14.7 │ +├───────────────────────────┼───────────┤ +│ 99.9% │ 15.0 │ +├───────────────────────────┼───────────┤ +│ max │ 57.5 │ +├───────────────────────────┼───────────┤ +│ Recordings available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Features available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Supervisions available: │ 167244 │ +╘═══════════════════════════╧═══════════╛ +Speech duration statistics: +╒══════════════════════════════╤═══════════╤══════════════════════╕ +│ Total speech duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total speaking time duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧═══════════╧══════════════════════╛ + +Eval2000 Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 03:37:13 │ +├───────────────────────────┼──────────┤ +│ mean │ 2.9 │ +├───────────────────────────┼──────────┤ +│ std │ 2.6 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 1.2 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 4.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 12.6 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 13.7 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 14.7 │ +├───────────────────────────┼──────────┤ +│ max │ 15.5 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 4473 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 03:37:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 03:37:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +""" diff --git a/egs/swbd/ASR/local/extend_segments.pl b/egs/swbd/ASR/local/extend_segments.pl new file mode 100755 index 000000000..e8b4894d5 --- /dev/null +++ 
b/egs/swbd/ASR/local/extend_segments.pl @@ -0,0 +1,99 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter + +if (@ARGV != 1 || !($ARGV[0] =~ m/^-?\d+\.?\d*$/ && $ARGV[0] >= 0)) { + print STDERR "Usage: extend_segments.pl time-in-seconds segments.extended \n" . + "e.g. extend_segments.pl 0.25 segments.2\n" . + "This command modifies a segments file, with lines like\n" . + " \n" . + "by extending the beginning and end of each segment by a certain\n" . + "length of time. This script makes sure the output segments do not\n" . + "overlap as a result of this time-extension, and that there are no\n" . + "negative times in the output.\n"; + exit 1; +} + +$extend = $ARGV[0]; + +@all_lines = (); + +while () { + chop; + @A = split(" ", $_); + if (@A != 4) { + die "invalid line in segments file: $_"; + } + $line = @all_lines; # current number of lines. + ($utt_id, $reco_id, $start_time, $end_time) = @A; + + push @all_lines, [ $utt_id, $reco_id, $start_time, $end_time ]; # anonymous array. + if (! defined $lines_for_reco{$reco_id}) { + $lines_for_reco{$reco_id} = [ ]; # push new anonymous array. + } + push @{$lines_for_reco{$reco_id}}, $line; +} + +foreach $reco_id (keys %lines_for_reco) { + $ref = $lines_for_reco{$reco_id}; + @line_numbers = sort { ${$all_lines[$a]}[2] <=> ${$all_lines[$b]}[2] } @$ref; + + + { + # handle start of earliest segment as a special case. + $l0 = $line_numbers[0]; + $tstart = ${$all_lines[$l0]}[2] - $extend; + if ($tstart < 0.0) { $tstart = 0.0; } + ${$all_lines[$l0]}[2] = $tstart; + } + { + # handle end of latest segment as a special case. + $lN = $line_numbers[$#line_numbers]; + $tend = ${$all_lines[$lN]}[3] + $extend; + ${$all_lines[$lN]}[3] = $tend; + } + for ($i = 0; $i < $#line_numbers; $i++) { + $ln = $line_numbers[$i]; + $ln1 = $line_numbers[$i+1]; + $tend = ${$all_lines[$ln]}[3]; # end of earlier segment. + $tstart = ${$all_lines[$ln1]}[2]; # start of later segment. + if ($tend > $tstart) { + $utt1 = ${$all_lines[$ln]}[0]; + $utt2 = ${$all_lines[$ln1]}[0]; + print STDERR "Warning: for utterances $utt1 and $utt2, segments " . + "already overlap; leaving these times unchanged.\n"; + } else { + $my_extend = $extend; + $max_extend = 0.5 * ($tstart - $tend); + if ($my_extend > $max_extend) { $my_extend = $max_extend; } + $tend += $my_extend; + $tstart -= $my_extend; + ${$all_lines[$ln]}[3] = $tend; + ${$all_lines[$ln1]}[2] = $tstart; + } + } +} + +# leave the numbering of the lines unchanged. +for ($l = 0; $l < @all_lines; $l++) { + $ref = $all_lines[$l]; + ($utt_id, $reco_id, $start_time, $end_time) = @$ref; + printf("%s %s %.2f %.2f\n", $utt_id, $reco_id, $start_time, $end_time); +} + +__END__ + +# testing below. + +# ( echo a1 A 0 1; echo a2 A 3 4; echo b1 B 0 1; echo b2 B 2 3 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.00 +a2 A 2.00 5.00 +b1 B 0.00 1.50 +b2 B 1.50 4.00 +# ( echo a1 A 0 2; echo a2 A 1 3 ) | local/extend_segments.pl 1.0 +Warning: for utterances a1 and a2, segments already overlap; leaving these times unchanged. +a1 A 0.00 2.00 +a2 A 1.00 4.00 +# ( echo a1 A 0 2; echo a2 A 5 6; echo a3 A 3 4 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.50 +a2 A 4.50 7.00 +a3 A 2.50 4.50 diff --git a/egs/swbd/ASR/local/filter_cuts.py b/egs/swbd/ASR/local/filter_cuts.py new file mode 100755 index 000000000..fbcc9e24a --- /dev/null +++ b/egs/swbd/ASR/local/filter_cuts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. + +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ + --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. 
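+    # Note: CutSet.filter() is evaluated lazily, so without to_eager() the
+    # predicate would not have run yet and `total`/`removed` would still be
+    # zero when they are logged below.  As a worked example of the
+    # subsampling check above: a 10 second cut has num_frames = 1000,
+    # giving T = ((1000 - 1) // 2 - 1) // 2 = 249 frames after subsampling.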
+ ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py new file mode 100755 index 000000000..6b3316800 --- /dev/null +++ b/egs/swbd/ASR/local/filter_empty_text.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path +import logging +from typing import List + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--kaldi-data-dir", + type=Path, + required=True, + help="Path to the kaldi data dir", + ) + + return parser.parse_args() + + +def load_segments(path: Path): + segments = {} + with open(path, "r") as f: + lines = f.readlines() + for line in lines: + line = line.strip() + utt_id, rec_id, start, end = line.split() + segments[utt_id] = line + return segments + + +def filter_text(path: Path): + with open(path, "r") as f: + lines = f.readlines() + return list(filter(lambda x: len(x.strip().split()) > 1, lines)) + + +def write_segments(path: Path, texts: List[str]): + with open(path, "w") as f: + f.writelines(texts) + + +def main(): + args = get_args() + orig_text_dict = filter_text(args.kaldi_data_dir / "text") + write_segments(args.kaldi_data_dir / "text", orig_text_dict) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() + + logging.info("Empty lines filtered") diff --git a/egs/swbd/ASR/local/format_acronyms_dict.py b/egs/swbd/ASR/local/format_acronyms_dict.py new file mode 100755 index 000000000..fa598dd03 --- /dev/null +++ b/egs/swbd/ASR/local/format_acronyms_dict.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd dict to fisher convention +# IBM to i._b._m. +# BBC to b._b._c. 
+# BBCs to b._b._c.s +# BBC's to b._b._c.'s + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input lexicon", required=True) +parser.add_argument("-o", "--output", help="Output lexicon", required=True) +parser.add_argument( + "-L", "--Letter", help="Input single letter pronunciation", required=True +) +parser.add_argument("-M", "--Map", help="Output acronyms mapping", required=True) +args = parser.parse_args() + + +fin_lex = open(args.input, "r") +fin_Letter = open(args.Letter, "r") +fout_lex = open(args.output, "w") +fout_map = open(args.Map, "w") + +# Initialise single letter dictionary +dict_letter = {} +for single_letter_lex in fin_Letter: + items = single_letter_lex.split() + dict_letter[items[0]] = single_letter_lex[len(items[0]) + 1 :].strip() +fin_Letter.close() +# print dict_letter + +for lex in fin_lex: + items = lex.split() + word = items[0] + lexicon = lex[len(items[0]) + 1 :].strip() + # find acronyms from words with only letters and ' + pre_match = re.match(r"^[A-Za-z]+$|^[A-Za-z]+\'s$|^[A-Za-z]+s$", word) + if pre_match: + # find if words in the form of xxx's is acronym + if word[-2:] == "'s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-2] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".'s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxxs is acronym + elif word[-1] == "s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-1] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxx (not ended with 's or s) is acronym + elif word.find("'") == -1 and word[-1] != "s": + acronym_lexicon = "" + for w in word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + word[-1].lower() + "." 
+ acronym_mapped_back = acronym_mapped_back + word[-1].lower() + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + else: + fout_lex.write(lex) + + else: + fout_lex.write(lex) diff --git a/egs/swbd/ASR/local/generate_unique_lexicon.py b/egs/swbd/ASR/local/generate_unique_lexicon.py new file mode 100755 index 000000000..3459c2f5a --- /dev/null +++ b/egs/swbd/ASR/local/generate_unique_lexicon.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input a lexicon.txt and output a new lexicon, +in which each word has a unique pronunciation. + +The way to do this is to keep only the first pronunciation of a word +in lexicon.txt. +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +from icefall.lexicon import read_lexicon, write_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + This file will generate a new file uniq_lexicon.txt + in it. + """, + ) + + return parser.parse_args() + + +def filter_multiple_pronunications( + lexicon: List[Tuple[str, List[str]]] +) -> List[Tuple[str, List[str]]]: + """Remove multiple pronunciations of words from a lexicon. + + If a word has more than one pronunciation in the lexicon, only + the first one is kept, while other pronunciations are removed + from the lexicon. + + Args: + lexicon: + The input lexicon, containing a list of (word, [p1, p2, ..., pn]), + where "p1, p2, ..., pn" are the pronunciations of the "word". + Returns: + Return a new lexicon where each word has a unique pronunciation. 
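+
+    Example:
+      If the input contains both ("okay", ["ow", "k", "ey"]) and
+      ("okay", ["k", "ey"]), only the first entry is kept.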
+ """ + seen = set() + ans = [] + + for word, tokens in lexicon: + if word in seen: + continue + seen.add(word) + ans.append((word, tokens)) + return ans + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + lexicon_filename = lang_dir / "lexicon.txt" + + in_lexicon = read_lexicon(lexicon_filename) + + out_lexicon = filter_multiple_pronunications(in_lexicon) + + write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) + + logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") + logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/map_acronyms_transcripts.py b/egs/swbd/ASR/local/map_acronyms_transcripts.py new file mode 100755 index 000000000..ba02aaec3 --- /dev/null +++ b/egs/swbd/ASR/local/map_acronyms_transcripts.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd transcript to fisher convention +# according to first two columns in the input acronyms mapping + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input transcripts", required=True) +parser.add_argument("-o", "--output", help="Output transcripts", required=True) +parser.add_argument("-M", "--Map", help="Input acronyms mapping", required=True) +args = parser.parse_args() + +fin_map = open(args.Map, "r") +dict_acronym = {} +dict_acronym_noi = {} # Mapping of acronyms without I, i +for pair in fin_map: + items = pair.split("\t") + dict_acronym[items[0]] = items[1] + dict_acronym_noi[items[0]] = items[1] +fin_map.close() +del dict_acronym_noi["I"] +del dict_acronym_noi["i"] + + +fin_trans = open(args.input, "r") +fout_trans = open(args.output, "w") +for line in fin_trans: + items = line.split() + L = len(items) + # First pass mapping to map I as part of acronym + for i in range(L): + if items[i] == "I": + x = 0 + while i - 1 - x >= 0 and re.match(r"^[A-Z]$", items[i - 1 - x]): + x += 1 + + y = 0 + while i + 1 + y < L and re.match(r"^[A-Z]$", items[i + 1 + y]): + y += 1 + + if x + y > 0: + for bias in range(-x, y + 1): + items[i + bias] = dict_acronym[items[i + bias]] + + # Second pass mapping (not mapping 'i' and 'I') + for i in range(len(items)): + if items[i] in dict_acronym_noi.keys(): + items[i] = dict_acronym_noi[items[i]] + sentence = " ".join(items[1:]) + fout_trans.write(items[0] + " " + sentence.lower() + "\n") + +fin_trans.close() +fout_trans.close() diff --git a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py new file mode 100755 index 000000000..20ab90caf --- /dev/null +++ b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +# replacement function to convert lowercase letter to uppercase +def to_upper(match_obj): + if match_obj.group() is not None: + return match_obj.group().upper() + + +def insert_groups_and_capitalize_3(match): + return f"{match.group(1)} {match.group(2)} {match.group(3)}".upper() + + +def insert_groups_and_capitalize_2(match): + return f"{match.group(1)} {match.group(2)}".upper() + + +def insert_groups_and_capitalize_1(match): + return f"{match.group(1)}".upper() + + +def insert_groups_and_capitalize_1s(match): + return f"{match.group(1)}".upper() + "'s" + + +class FisherSwbdNormalizer: + """Note: the functions "normalize" and "keep" implement the logic + similar to Kaldi's data prep scripts for Fisher and SWBD: One + notable difference is that we don't change [cough], [lipsmack], + etc. to [noise]. We also don't implement all the edge cases of + normalization from Kaldi (hopefully won't make too much + difference). + """ + + def __init__(self) -> None: + self.remove_regexp_before = re.compile( + r"|".join( + [ + # special symbols + r"\[\[skip.*\]\]", + r"\[skip.*\]", + r"\[pause.*\]", + r"\[silence\]", + r"", + r"", + r"_1", + ] + ) + ) + + # tuples of (pattern, replacement) + # note: Kaldi replaces sighs, coughs, etc with [noise]. + # We don't do that here. + # We also lowercase the text as the first operation. 
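+        # The (pattern, replacement) pairs below are applied in order to the
+        # already-lowercased text, e.g. "[laughter-story]" becomes "story"
+        # before the more generic bracket rules run.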
+ self.replace_regexps: Tuple[re.Pattern, str] = [ + # SWBD: + # [LAUGHTER-STORY] -> STORY + (re.compile(r"\[laughter-(.*?)\]"), r"\1"), + # [WEA[SONABLE]-/REASONABLE] + (re.compile(r"\[\S+/(\S+)\]"), r"\1"), + # -[ADV]AN[TAGE]- -> AN + (re.compile(r"-?\[.*?\](\w+)\[.*?\]-?"), r"\1-"), + # ABSOLUTE[LY]- -> ABSOLUTE- + (re.compile(r"(\w+)\[.*?\]-?"), r"\1-"), + # [AN]Y- -> Y- + # -[AN]Y- -> Y- + (re.compile(r"-?\[.*?\](\w+)-?"), r"\1-"), + # special tokens + (re.compile(r"\[laugh.*?\]"), r"[laughter]"), + (re.compile(r"\[sigh.*?\]"), r"[sigh]"), + (re.compile(r"\[cough.*?\]"), r"[cough]"), + (re.compile(r"\[mn.*?\]"), r"[vocalized-noise]"), + (re.compile(r"\[breath.*?\]"), r"[breath]"), + (re.compile(r"\[lipsmack.*?\]"), r"[lipsmack]"), + (re.compile(r"\[sneeze.*?\]"), r"[sneeze]"), + # abbreviations + ( + re.compile( + r"(\w)\.(\w)\.(\w)", + ), + insert_groups_and_capitalize_3, + ), + ( + re.compile( + r"(\w)\.(\w)", + ), + insert_groups_and_capitalize_2, + ), + ( + re.compile( + r"([a-h,j-z])\.", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"\._", + ), + r" ", + ), + ( + re.compile( + r"_(\w)", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"(\w)\.s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"([A-Z])\'s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"(\s\w\b|^\w\b)", + ), + insert_groups_and_capitalize_1, + ), + # words between apostrophes + (re.compile(r"'(\S*?)'"), r"\1"), + # dangling dashes (2 passes) + (re.compile(r"\s-\s"), r" "), + (re.compile(r"\s-\s"), r" "), + # special symbol with trailing dash + (re.compile(r"(\[.*?\])-"), r"\1"), + # Just remove all dashes + (re.compile(r"-"), r" "), + ] + + # unwanted symbols in the transcripts + self.remove_regexp_after = re.compile( + r"|".join( + [ + # remaining punctuation + r"\.", + r",", + r"\?", + r"{", + r"}", + r"~", + r"_\d", + ] + ) + ) + + self.post_fixes = [ + # Fix an issue related to [VOCALIZED NOISE] after dash removal + (re.compile(r"\[vocalized noise\]"), "[vocalized-noise]"), + ] + + self.whitespace_regexp = re.compile(r"\s+") + + def normalize(self, text: str) -> str: + text = text.lower() + + # first remove + text = self.remove_regexp_before.sub("", text) + + # then replace + for pattern, sub in self.replace_regexps: + text = pattern.sub(sub, text) + + # then remove + text = self.remove_regexp_after.sub("", text) + + # post fixes + for pattern, sub in self.post_fixes: + text = pattern.sub(sub, text) + + # then clean up whitespace + text = self.whitespace_regexp.sub(" ", text).strip() + + return text.upper() + + +def keep(sup: SupervisionSegment) -> bool: + if "((" in sup.text: + return False + + if " yes", + "[laugh] oh this is [laught] this is great [silence] yes", + "i don't kn- - know A.B.C's", + "so x. 
corp is good?", + "'absolutely yes", + "absolutely' yes", + "'absolutely' yes", + "'absolutely' yes 'aight", + "ABSOLUTE[LY]", + "ABSOLUTE[LY]-", + "[AN]Y", + "[AN]Y-", + "[ADV]AN[TAGE]", + "[ADV]AN[TAGE]-", + "-[ADV]AN[TAGE]", + "-[ADV]AN[TAGE]-", + "[WEA[SONABLE]-/REASONABLE]", + "[VOCALIZED-NOISE]-", + "~BULL", + "Frank E Peretti P E R E T T I", + "yeah yeah like Double O Seven he's supposed to do it", + "P A P E R paper", + "[noise] okay_1 um let me see [laughter] i've been sitting here awhile", + ]: + print(text) + print(normalizer.normalize(text)) + print() + + +if __name__ == "__main__": + test() + # exit() + main() diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py new file mode 100755 index 000000000..7316193d0 --- /dev/null +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +def remove_punctutation_and_other_symbol(text: str) -> str: + text = text.replace("--", " ") + text = text.replace("//", " ") + text = text.replace(".", " ") + text = text.replace("?", " ") + text = text.replace("~", " ") + text = text.replace(",", " ") + text = text.replace(";", " ") + text = text.replace("(", " ") + text = text.replace(")", " ") + text = text.replace("&", " ") + text = text.replace("%", " ") + text = text.replace("*", " ") + text = text.replace("{", " ") + text = text.replace("}", " ") + return text + + +def eval2000_clean_eform(text: str, eform_count) -> str: + string_to_remove = [] + piece = text.split('">') + for i in range(0, len(piece)): + s = piece[i] + '">' + res = re.search(r"", s) + if res is not None: + res_rm = res.group(1) + string_to_remove.append(res_rm) + for p in string_to_remove: + eform_string = p + text = text.replace(eform_string, " ") + eform_1 = " str: + text = text.replace("[/BABY CRYING]", " ") + text = text.replace("[/CHILD]", " ") + text = text.replace("[[DISTORTED]]", " ") + text = text.replace("[/DISTORTION]", " ") + text = text.replace("[[DRAWN OUT]]", " ") + text = text.replace("[[DRAWN-OUT]]", " ") + text = text.replace("[[FAINT]]", " ") + text = text.replace("[SMACK]", " ") + text = text.replace("[[MUMBLES]]", " ") + text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ") + text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]") + text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]") + text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " ") + text = text.replace("[[PREVIOUS WORD SPOKEN WITH A 
LAUGH]]", " ") + text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ") + text = text.replace("[[PROLONGED]]", " ") + text = text.replace("[/RUNNING WATER]", " ") + text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]") + text = text.replace("[[SINGING]]", " ") + text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]") + text = text.replace("[/STATIC]", " ") + text = text.replace("['THIRTIETH' DRAWN OUT]", " ") + text = text.replace("[/VOICES]", " ") + text = text.replace("[[WHISPERED]]", " ") + text = text.replace("[DISTORTION]", " ") + text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ") + text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]") + text = text.replace("[CHILD'S VOICE]", " ") + text = text.replace("[CHILD SCREAMS]", " ") + text = text.replace("[CHILD VOICE]", " ") + text = text.replace("[CHILD YELLING]", " ") + text = text.replace("[CHILD SCREAMING]", " ") + text = text.replace("[CHILD'S VOICE IN BACKGROUND]", " ") + text = text.replace("[CHANNEL NOISE]", " ") + text = text.replace("[CHANNEL ECHO]", " ") + text = text.replace("[ECHO FROM OTHER CHANNEL]", " ") + text = text.replace("[ECHO OF OTHER CHANNEL]", " ") + text = text.replace("[CLICK]", " ") + text = text.replace("[DISTORTED]", " ") + text = text.replace("[BABY CRYING]", " ") + text = text.replace("[METALLIC KNOCKING SOUND]", " ") + text = text.replace("[METALLIC SOUND]", " ") + + text = text.replace("[PHONE JIGGLING]", " ") + text = text.replace("[BACKGROUND SOUND]", " ") + text = text.replace("[BACKGROUND VOICE]", " ") + text = text.replace("[BACKGROUND VOICES]", " ") + text = text.replace("[BACKGROUND NOISE]", " ") + text = text.replace("[CAR HORNS IN BACKGROUND]", " ") + text = text.replace("[CAR HORNS]", " ") + text = text.replace("[CARNATING]", " ") + text = text.replace("[CRYING CHILD]", " ") + text = text.replace("[CHOPPING SOUND]", " ") + text = text.replace("[BANGING]", " ") + text = text.replace("[CLICKING NOISE]", " ") + text = text.replace("[CLATTERING]", " ") + text = text.replace("[ECHO]", " ") + text = text.replace("[KNOCK]", " ") + text = text.replace("[NOISE-GOOD]", "[NOISE]") + text = text.replace("[RIGHT]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[SQUEAK]", " ") + text = text.replace("[STATIC]", " ") + text = text.replace("[[SAYS WITH HIGH-PITCHED SCREAMING LAUGHTER]]", " ") + text = text.replace("[UH]", "UH") + text = text.replace("[MN]", "[VOCALIZED-NOISE]") + text = text.replace("[VOICES]", " ") + text = text.replace("[WATER RUNNING]", " ") + text = text.replace("[SOUND OF TWISTING PHONE CORD]", " ") + text = text.replace("[SOUND OF SOMETHING FALLING]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[NOISE OF MOVING PHONE]", " ") + text = text.replace("[SOUND OF RUNNING WATER]", " ") + text = text.replace("[CHANNEL]", " ") + text = text.replace("[SILENCE]", " ") + text = text.replace("-[W]HERE", "WHERE") + text = text.replace("Y[OU]I-", "YOU I") + text = text.replace("-[A]ND", "AND") + text = text.replace("JU[ST]", "JUST") + text = text.replace("{BREATH}", " ") + text = text.replace("{BREATHY}", " ") + text = text.replace("{CHANNEL NOISE}", " ") + text = text.replace("{CLEAR THROAT}", " ") + + text = text.replace("{CLEARING THROAT}", " ") + text = text.replace("{CLEARS THROAT}", " ") + text = text.replace("{COUGH}", " ") + text = text.replace("{DRAWN OUT}", " ") + text = text.replace("{EXHALATION}", " ") + text = text.replace("{EXHALE}", " ") + text = text.replace("{GASP}", " ") + text = text.replace("{HIGH 
SQUEAL}", " ") + text = text.replace("{INHALE}", " ") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LIPSMACK}", " ") + text = text.replace("{LIPSMACK}", " ") + + text = text.replace("{NOISE OF DISGUST}", " ") + text = text.replace("{SIGH}", " ") + text = text.replace("{SNIFF}", " ") + text = text.replace("{SNORT}", " ") + text = text.replace("{SHARP EXHALATION}", " ") + text = text.replace("{BREATH LAUGH}", " ") + + text = text.replace("[LAUGHTER]", " ") + text = text.replace("[NOISE]", " ") + text = text.replace("[VOCALIZED-NOISE]", " ") + text = text.replace("-", " ") + return text + + +def remove_languagetag(text: str) -> str: + langtag = re.findall(r"<(.*?)>", text) + for t in langtag: + text = text.replace(t, " ") + text = text.replace("<", " ") + text = text.replace(">", " ") + return text + + +def eval2000_normalizer(text: str) -> str: + # print("TEXT original: ",text) + eform_count = text.count("contraction e_form") + # print("eform corunt:", eform_count) + if eform_count > 0: + text = eval2000_clean_eform(text, eform_count) + text = text.upper() + text = remove_languagetag(text) + text = replace_silphone(text) + text = remove_punctutation_and_other_symbol(text) + text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ") + text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ") + spaces = re.findall(r"\s+", text) + for sp in spaces: + text = text.replace(sp, " ") + text = text.strip() + # text = self.whitespace_regexp.sub(" ", text).strip() + # print(text) + return text + + +def main(): + args = get_args() + sups = load_manifest_lazy_or_eager(args.input_sups) + assert isinstance(sups, SupervisionSet) + + tot, skip = 0, 0 + with SupervisionSet.open_writer(args.output_sups) as writer: + for sup in tqdm(sups, desc="Normalizing supervisions"): + tot += 1 + sup.text = eval2000_normalizer(sup.text) + if not sup.text: + skip += 1 + continue + writer.write(sup) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py new file mode 100755 index 000000000..d82a085ec --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang_bpe.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.utils import str2bool + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str], oov: str +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) + + # Now convert word piece IDs back to word piece strings. 
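+    # e.g. for the word "HELLO" the IDs might map back to pieces such as
+    # ["▁HE", "LLO"]; the exact pieces depend on the trained bpe.model.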
+ words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = [ + "", + "!SIL", + "", + args.oov, + "#0", + "", + "", + ] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py new file mode 120000 index 000000000..abc00d421 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/rt03_data_prep.sh b/egs/swbd/ASR/local/rt03_data_prep.sh new file mode 100755 index 000000000..8a5f64324 --- /dev/null +++ b/egs/swbd/ASR/local/rt03_data_prep.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash + +# RT-03 data preparation (conversational telephone speech part only) +# Adapted from Arnab Ghoshal's script for Hub-5 Eval 2000 by Peng Qi + +# To be run from one directory above this 
script. + +# Expects the standard directory layout for RT-03 + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/corpora/LDC/LDC2007S10" + echo "See comments in the script for more details" + exit 1 +fi + +sdir=$1 +[ ! -d $sdir/data/audio/eval03/english/cts ] && + echo Expecting directory $sdir/data/audio/eval03/english/cts to be present && exit 1 +[ ! -d $sdir/data/references/eval03/english/cts ] && + echo Expecting directory $tdir/data/references/eval03/english/cts to be present && exit 1 + +dir=data/local/rt03 +mkdir -p $dir + +rtroot=$sdir +tdir=$sdir/data/references/eval03/english/cts +sdir=$sdir/data/audio/eval03/english/cts + +find -L $sdir -iname '*.sph' | sort >$dir/sph.flist +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 + +awk -v sph2pipe=$sph2pipe '{ + printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); + printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 + +# Get segments file... +# segments file format is: utt-id side-id start-time end-time, e.g.: +# sw02001-A_000098-001156 sw02001-A 0.98 11.56 +#pem=$sdir/english/hub5e_00.pem +#[ ! -f $pem ] && echo "No such file $pem" && exit 1; +# pem file has lines like: +# en_4156 A unknown_speaker 301.85 302.48 + +#grep -v ';;' $pem \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + print utt,spk,$4,$5;}' | + sort -u >$dir/segments + +# stm file has lines like: +# en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER +# TODO(arnab): We should really be lowercasing this since the Edinburgh +# recipe uses lowercase. This is not used in the actual scoring. +#grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' | + sort >$dir/text.all + +# We'll use the stm file for sclite scoring. There seem to be various errors +# in the stm file that upset hubscr.pl, and we fix them here. +cat $tdir/*.stm | + sed -e 's:((:(:' -e 's:::g' -e 's:::g' | + grep -v inter_segment_gap | + awk '{ + printf $1; if ($1==";;") printf(" %s",$2); else printf(($2==1)?" A":" B"); for(n=3;n<=NF;n++) printf(" %s", $n); print ""; }' \ + >$dir/stm +#$tdir/reference/hub5e00.english.000405.stm > $dir/stm +cp $rtroot/data/trans_rules/en20030506.glm $dir/glm + +# next line uses command substitution +# Just checking that the segments are the same in pem vs. stm. +! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && + echo "Segments from pem file and stm file do not match." && exit 1 + +grep -v IGNORE_TIME_SEGMENT_ $dir/text.all >$dir/text + +# create an utt2spk file that assumes each conversation side is +# a separate speaker. 
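+# Each line of utt2spk is "utt-id conversation-side", e.g.
+#   sw02001-A_000098-001156 sw02001-A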
+awk '{print $1,$2;}' $dir/segments >$dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk >$dir/spk2utt + +# cp $dir/segments $dir/segments.tmp +# awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ +# $dir/segments.tmp > $dir/segments + +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + >$dir/reco2file_and_channel || exit 1 + +./utils/fix_data_dir.sh $dir + +echo Data preparation and formatting completed for RT-03 +echo "(but not MFCC extraction)" diff --git a/egs/swbd/ASR/local/sort_lm_training_data.py b/egs/swbd/ASR/local/sort_lm_training_data.py new file mode 100755 index 000000000..bed3856e4 --- /dev/null +++ b/egs/swbd/ASR/local/sort_lm_training_data.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input the filename of LM training data +generated by ./local/prepare_lm_training_data.py and sorts +it by sentence length. + +Sentence length equals to the number of BPE tokens in a sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/lm_data.pt", + ) + + parser.add_argument( + "--out-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", + ) + + parser.add_argument( + "--out-statistics", + type=str, + help="Statistics about LM training data., data/bpe_500/statistics.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + in_lm_data = Path(args.in_lm_data) + out_lm_data = Path(args.out_lm_data) + assert in_lm_data.is_file(), f"{in_lm_data}" + if out_lm_data.is_file(): + logging.warning(f"{out_lm_data} exists - skipping") + return + data = torch.load(in_lm_data) + words2bpe = data["words"] + sentences = data["sentences"] + sentence_lengths = data["sentence_lengths"] + + num_sentences = sentences.dim0 + assert num_sentences == sentence_lengths.numel(), ( + num_sentences, + sentence_lengths.numel(), + ) + + indices = torch.argsort(sentence_lengths, descending=True) + + sorted_sentences = sentences[indices.to(torch.int32)] + sorted_sentence_lengths = sentence_lengths[indices] + + # Check that sentences are ordered by length + assert num_sentences == sorted_sentences.dim0, ( + num_sentences, + sorted_sentences.dim0, + ) + + cur = None + for i in range(num_sentences): + word_ids = sorted_sentences[i] + token_ids = words2bpe[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + if cur is not None: + assert cur >= token_ids.numel(), (cur, token_ids.numel()) + + cur = token_ids.numel() + assert cur == sorted_sentence_lengths[i] + + data["sentences"] = sorted_sentences + 
data["sentence_lengths"] = sorted_sentence_lengths + torch.save(data, args.out_lm_data) + logging.info(f"Saved to {args.out_lm_data}") + + statistics = Path(args.out_statistics) + + # Write statistics + num_words = sorted_sentences.numel() + num_tokens = sentence_lengths.sum().item() + max_sentence_length = sentence_lengths[indices[0]] + min_sentence_length = sentence_lengths[indices[-1]] + + step = 10 + hist, bins = np.histogram( + sentence_lengths.numpy(), + bins=np.arange(1, max_sentence_length + step, step), + ) + + histogram = np.stack((bins[:-1], hist)).transpose() + + with open(statistics, "w") as f: + f.write(f"num_sentences: {num_sentences}\n") + f.write(f"num_words: {num_words}\n") + f.write(f"num_tokens: {num_tokens}\n") + f.write(f"max_sentence_length: {max_sentence_length}\n") + f.write(f"min_sentence_length: {min_sentence_length}\n") + f.write("histogram:\n") + f.write(" bin count percent\n") + for row in histogram: + f.write( + f"{int(row[0]):>5} {int(row[1]):>5} " + f"{100.*row[1]/num_sentences:.3f}%\n" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/local/swbd1_data_prep.sh b/egs/swbd/ASR/local/swbd1_data_prep.sh new file mode 100755 index 000000000..159359491 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_data_prep.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash + +# Switchboard-1 training data preparation customized for Edinburgh +# Author: Arnab Ghoshal (Jan 2013) + +# To be run from one directory above this script. + +## The input is some directory containing the switchboard-1 release 2 +## corpus (LDC97S62). Note: we don't make many assumptions about how +## you unpacked this. We are just doing a "find" command to locate +## the .sph files. + +## The second input is optional, which should point to a directory containing +## Switchboard transcriptions/documentations (specifically, the conv.tab file). +## If specified, the script will try to use the actual speaker PINs provided +## with the corpus instead of the conversation side ID (Kaldi default). We +## will be using "find" to locate this file so we don't make any assumptions +## on the directory structure. (Peng Qi, Aug 2014) + +#check existing directories +if [ $# != 1 -a $# != 2 ]; then + echo "Usage: swbd1_data_prep.sh /path/to/SWBD [/path/to/SWBD_DOC]" + exit 1 +fi + +SWBD_DIR=$1 + +dir=data/local/train +mkdir -p $dir + +# Audio data directory check +if [ ! -d $SWBD_DIR ]; then + echo "Error: run.sh requires a directory argument" + exit 1 +fi + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 + +# Option A: SWBD dictionary file check +[ ! -f ./swb_ms98_transcriptions/sw-ms98-dict.text ] && + echo "SWBD dictionary file does not exist" && exit 1 + +# find sph audio files +find -L $SWBD_DIR -iname '*.sph' | sort >$dir/sph.flist + +n=$(cat $dir/sph.flist | wc -l) +[ $n -ne 2435 ] && [ $n -ne 2438 ] && + echo Warning: expected 2435 or 2438 data data files, found $n + +# (1a) Transcriptions preparation +# make basic transcription file (add segments info) +# **NOTE: In the default Kaldi recipe, everything is made uppercase, while we +# make everything lowercase here. This is because we will be using SRILM which +# can optionally make everything lowercase (but not uppercase) when mapping +# LM vocabs. 
+awk '{ +name=substr($1,1,6); gsub("^sw","sw0",name); side=substr($1,7,1); +stime=$2; etime=$3; +printf("%s-%s_%06.0f-%06.0f", +name, side, int(100*stime+0.5), int(100*etime+0.5)); +for(i=4;i<=NF;i++) printf(" %s", $i); printf "\n" +}' ./swb_ms98_transcriptions/*/*/*-trans.text >$dir/transcripts1.txt + +# test if trans. file is sorted +export LC_ALL=C +sort -c $dir/transcripts1.txt || exit 1 # check it's sorted. + +# Remove SILENCE, and . + +# Note: we have [NOISE], [VOCALIZED-NOISE], [LAUGHTER], [SILENCE]. +# removing [SILENCE], and the and markers that mark +# speech to somone; we will give phones to the other three (NSN, SPN, LAU). +# There will also be a silence phone, SIL. +# **NOTE: modified the pattern matches to make them case insensitive +cat $dir/transcripts1.txt | + perl -ane 's:\s\[SILENCE\](\s|$):$1:gi; + s///gi; + s///gi; + print;' | + awk '{if(NF > 1) { print; } } ' >$dir/transcripts2.txt + +# **NOTE: swbd1_map_words.pl has been modified to make the pattern matches +# case insensitive +local/swbd1_map_words.pl -f 2- $dir/transcripts2.txt >$dir/text + +# format acronyms in text +python3 local/map_acronyms_transcripts.py -i $dir/text -o $dir/text_map \ + -M data/local/dict_nosp/acronyms.map +mv $dir/text_map $dir/text + +# (1c) Make segment files from transcript +#segments file format is: utt-id side-id start-time end-time, e.g.: +#sw02001-A_000098-001156 sw02001-A 0.98 11.56 +awk '{ +segment=$1; +split(segment,S,"[_-]"); +side=S[2]; audioname=S[1]; startf=S[3]; endf=S[4]; +print segment " " audioname "-" side " " startf/100 " " endf/100 +}' <$dir/text >$dir/segments + +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp + +awk -v sph2pipe=$sph2pipe '{ +printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); +printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 + +# this file reco2file_and_channel maps recording-id (e.g. sw02001-A) +# to the file name sw02001 and the A, e.g. +# sw02001-A sw02001 A +# In this case it's trivial, but in other corpora the information might +# be less obvious. Later it will be needed for ctm scoring. +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + >$dir/reco2file_and_channel || exit 1 + +awk '{spk=substr($1,1,9); print $1 " " spk}' $dir/segments >$dir/utt2spk || + exit 1 +sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl >$dir/spk2utt || exit 1 + +echo Switchboard-1 data preparation succeeded. + +utils/fix_data_dir.sh data/local/train diff --git a/egs/swbd/ASR/local/swbd1_map_words.pl b/egs/swbd/ASR/local/swbd1_map_words.pl new file mode 100755 index 000000000..4fb8d4ffe --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_map_words.pl @@ -0,0 +1,52 @@ +#!/usr/bin/env perl + +# Modified from swbd_map_words.pl in Kaldi s5 recipe to make pattern +# matches case-insensitive --Arnab (Jan 2013) + +if ($ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. 
+ } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } +} + + +while (<>) { + @A = split(" ", $_); + for ($n = 0; $n < @A; $n++) { + $a = $A[$n]; + if ( (!defined $field_begin || $n >= $field_begin) + && (!defined $field_end || $n <= $field_end)) { + # e.g. [LAUGHTER-STORY] -> STORY; + $a =~ s:(|\-)^\[LAUGHTER-(.+)\](|\-)$:$1$2$3:i; + # $1 and $3 relate to preserving trailing "-" + $a =~ s:^\[(.+)/.+\](|\-)$:$1$2:; # e.g. [IT'N/ISN'T] -> IT'N ... note, + # 1st part may include partial-word stuff, which we process further below, + # e.g. [LEM[GUINI]-/LINGUINI] + # the (|\_) at the end is to accept and preserve trailing -'s. + $a =~ s:^(|\-)\[[^][]+\](.+)$:-$2:; # e.g. -[AN]Y , note \047 is quote; + # let the leading - be optional on input, as sometimes omitted. + $a =~ s:^(.+)\[[^][]+\](|\-)$:$1-:; # e.g. AB[SOLUTE]- -> AB-; + # let the trailing - be optional on input, as sometimes omitted. + $a =~ s:([^][]+)\[.+\]$:$1:; # e.g. EX[SPECIALLY]-/ESPECIALLY] -> EX- + # which is a mistake in the input. + $a =~ s:^\{(.+)\}$:$1:; # e.g. {YUPPIEDOM} -> YUPPIEDOM + $a =~ s:[A-Z]\[([^][])+\][A-Z]:$1-$3:i; # e.g. AMMU[N]IT- -> AMMU-IT- + $a =~ s:_\d$::; # e.g. THEM_1 -> THEM + } + $A[$n] = $a; + } + print join(" ", @A) . "\n"; +} diff --git a/egs/swbd/ASR/local/swbd1_prepare_dict.sh b/egs/swbd/ASR/local/swbd1_prepare_dict.sh new file mode 100755 index 000000000..eff5fb5f1 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_prepare_dict.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash + +# Formatting the Mississippi State dictionary for use in Edinburgh. Differs +# from the one in Kaldi s5 recipe in that it uses lower-case --Arnab (Jan 2013) + +# To be run from one directory above this script. + +#check existing directories +[ $# != 0 ] && echo "Usage: local/swbd1_data_prep.sh" && exit 1 + +srcdir=. # This is where we downloaded some stuff.. +dir=./data/local/dict_nosp +mkdir -p $dir +srcdict=$srcdir/swb_ms98_transcriptions/sw-ms98-dict.text + +# assume swbd_p1_data_prep.sh was done already. +[ ! -f "$srcdict" ] && echo "$0: No such file $srcdict" && exit 1 + +cp $srcdict $dir/lexicon0.txt || exit 1 +chmod a+w $dir/lexicon0.txt +patch 0' | sort >$dir/lexicon1.txt || exit 1 + +cat $dir/lexicon1.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | + grep -v sil >$dir/nonsilence_phones.txt || exit 1 + +( + echo sil + echo spn + echo nsn + echo lau +) >$dir/silence_phones.txt + +echo sil >$dir/optional_silence.txt + +# No "extra questions" in the input to this setup, as we don't +# have stress or tone. +echo -n >$dir/extra_questions.txt + +cp local/MSU_single_letter.txt $dir/ +# Add to the lexicon the silences, noises etc. +# Add single letter lexicon +# The original swbd lexicon does not have precise single letter lexicion +# e.g. it does not have entry of W +( + echo '!SIL SIL' + echo '[VOCALIZED-NOISE] spn' + echo '[NOISE] nsn' + echo '[LAUGHTER] lau' + echo ' spn' +) | + cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt >$dir/lexicon2.txt || exit 1 + +# Map the words in the lexicon. That is-- for each word in the lexicon, we map it +# to a new written form. The transformations we do are: +# remove laughter markings, e.g. +# [LAUGHTER-STORY] -> STORY +# Remove partial-words, e.g. +# -[40]1K W AH N K EY +# becomes -1K +# and +# -[AN]Y IY +# becomes +# -Y +# -[A]B[OUT]- B +# becomes +# -B- +# Also, curly braces, which appear to be used for "nonstandard" +# words or non-words, are removed, e.g. 
+# {WOLMANIZED} W OW L M AX N AY Z D +# -> WOLMANIZED +# Also, mispronounced words, e.g. +# [YEAM/YEAH] Y AE M +# are changed to just e.g. YEAM, i.e. the orthography +# of the mispronounced version. +# Note-- this is only really to be used in training. The main practical +# reason is to avoid having tons of disambiguation symbols, which +# we otherwise would get because there are many partial words with +# the same phone sequences (most problematic: S). +# Also, map +# THEM_1 EH M -> THEM +# so that multiple pronunciations just have alternate entries +# in the lexicon. + +local/swbd1_map_words.pl -f 1 $dir/lexicon2.txt | sort -u \ + >$dir/lexicon3.txt || exit 1 + +python3 local/format_acronyms_dict.py -i $dir/lexicon3.txt -o $dir/lexicon4.txt \ + -L $dir/MSU_single_letter.txt -M $dir/acronyms_raw.map +cat $dir/acronyms_raw.map | sort -u >$dir/acronyms.map + +(echo 'i ay') | cat - $dir/lexicon4.txt | tr '[A-Z]' '[a-z]' | sort -u >$dir/lexicon5.txt + +pushd $dir >&/dev/null +ln -sf lexicon5.txt lexicon.txt # This is the final lexicon. +popd >&/dev/null +rm $dir/lexiconp.txt 2>/dev/null +echo Prepared input dictionary and phone-sets for Switchboard phase 1. diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..9b4e28635 --- /dev/null +++ b/egs/swbd/ASR/local/train_bpe_model.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. 
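+    # The two user-defined symbols above take IDs 0 and 1, which is why the
+    # unknown token ends up with ID 2.  The SWBD special tokens below are
+    # also registered as user-defined symbols so that the BPE model keeps
+    # each of them as a single, unsplittable piece.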
+ + user_defined_symbols += ["[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]"] + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/swbd/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh new file mode 100755 index 000000000..47d12613b --- /dev/null +++ b/egs/swbd/ASR/prepare.sh @@ -0,0 +1,463 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. Most of them can't be downloaded automatically +# as they are not publically available and require a license purchased +# from the LDC. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=./download +# swbd1_dir="/export/corpora3/LDC/LDC97S62" +swbd1_dir=./download/LDC97S62/ + +# eval2000_dir contains the following files and directories +# downloaded from LDC website: +# - LDC2002S09 +# - hub5e_00 +# - LDC2002T43 +# - reference +eval2000_dir="/export/corpora2/LDC/eval2000" + +rt03_dir="/export/corpora/LDC/LDC2007S10" +fisher_dir="/export/corpora3/LDC/LDC2004T19" + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "swbd1_dir: $swbd1_dir" +log "eval2000_dir: $eval2000_dir" +log "rt03_dir: $rt03_dir" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare SwitchBoard manifest" + # We assume that you have downloaded the SwitchBoard corpus + # to respective dirs + mkdir -p data/manifests + if [ ! 
-e data/manifests/.swbd.done ]; then + lhotse prepare switchboard --absolute-paths 1 --omit-silence $swbd1_dir data/manifests/swbd + ./local/normalize_and_filter_supervisions.py \ + data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all.jsonl.gz data/manifests/swbd/swbd_supervisions_orig.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/swbd/swbd_recordings_all.jsonl.gz \ + -s data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all.jsonl.gz + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/swbd/swbd_train_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz + + num_splits=16 + mkdir -p data/manifests/swbd_split${num_splits} + lhotse split ${num_splits} \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz \ + data/manifests/swbd_split${num_splits} + + lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000 + ./local/normalize_eval2000.py \ + data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \ + data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/eval2000/eval2000_recordings_all.jsonl.gz \ + -s data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz + + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz + + sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ + $eval2000_dir/LDC2002T43/reference/hub5e00.english.000405.stm > data/manifests/eval2000/stm + cp $eval2000_dir/LDC2002T43/reference/en20000405_hub5.glm $dir/glm + + # ./local/rt03_data_prep.sh $rt03_dir + + # normalize eval2000 and rt03 texts by + # 1) convert upper to lower + # 2) remove tags (%AH) (%HESITATION) (%UH) + # 3) remove + # 4) remove "(" or ")" + # for x in rt03; do + # cp data/local/${x}/text data/local/${x}/text.org + # paste -d "" \ + # <(cut -f 1 -d" " data/local/${x}/text.org) \ + # <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") | + # sed -e 's/\s\+/ /g' >data/local/${x}/text + # rm data/local/${x}/text.org + # done + + # lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests + + touch data/manifests/.swbd.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3 I: Compute fbank for SwitchBoard" + if [ ! 
-e data/fbank/.swbd.done ]; then + mkdir -p data/fbank/swbd_split${num_splits}/ + for index in $(seq 1 16); do + ./local/compute_fbank_swbd.py --split-index ${index} & + done + wait + pieces=$(find data/fbank/swbd_split${num_splits} -name "swbd_cuts_all.*.jsonl.gz") + lhotse combine $pieces data/fbank/swbd_cuts_all.jsonl.gz + touch data/fbank/.swbd.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3 II: Compute fbank for eval2000" + if [ ! -e data/fbank/.eval2000.done ]; then + mkdir -p data/fbank/eval2000/ + ./local/compute_fbank_eval2000.py + touch data/fbank/.eval2000.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + if ! which jq; then + echo "This script is intended to be used with jq but you have not installed jq + Note: in Linux, you can install jq with the following command: + 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + 2. chmod +x ./jq + 3. cp jq /usr/bin" && exit 1 + fi + if [ ! -f $lang_dir/text ] || [ ! -s $lang_dir/text ]; then + log "Prepare text." + gunzip -c data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $lang_dir/text + fi + + log "Prepare dict" + ./local/swbd1_prepare_dict.sh + cut -f 2- -d" " $lang_dir/text >${lang_dir}/input.txt + # [noise] nsn + # !sil sil + # spn + cat data/local/dict_nosp/lexicon.txt | sed 's/-//g' | sed 's/\[vocalizednoise\]/\[vocalized-noise\]/g' | + sort | uniq >$lang_dir/lexicon_lower.txt + + cat $lang_dir/lexicon_lower.txt | tr a-z A-Z > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + + cat data/lang_phone/text | cut -d " " -f 2- >$lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! 
-f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare bigram token-level P for MMI training" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + >$lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa >$lang_dir/P.fst.txt + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" + lang_dir=data/lang_phone + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/3-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.arpa >data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + ./shared/make_kn_lm.py \ + -ngram-order 4 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/4-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + data/lm/4-gram.arpa >data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Generate LM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! 
-f $out_dir/train.txt ]; then + tail -n 250000 data/lang_phone/input.txt > $out_dir/train.txt + fi + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data data/lang_phone/input.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Generate LM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + head -n 14332 data/lang_phone/input.txt > $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Generate LM test data" + testsets=(eval2000) + + for testset in ${testsets[@]}; do + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/${testset}.txt ]; then + gunzip -c data/manifests/${testset}/eval2000_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $out_dir/${testset}.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/${testset}.txt \ + --lm-archive $out_dir/lm_data-${testset}.pt + done + done +fi + +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Sort LM training data" + testsets=(eval2000) + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + for testset in ${testsets[@]}; do + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-${testset}.pt \ + --out-lm-data $out_dir/sorted_lm_data-${testset}.pt \ + --out-statistics $out_dir/statistics-test-${testset}.txt + done + done +fi diff --git a/egs/swbd/ASR/shared b/egs/swbd/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/swbd/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/swbd/ASR/utils/filter_scp.pl b/egs/swbd/ASR/utils/filter_scp.pl new file mode 100755 index 000000000..b76d37f41 --- /dev/null +++ b/egs/swbd/ASR/utils/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/swbd/ASR/utils/fix_data_dir.sh b/egs/swbd/ASR/utils/fix_data_dir.sh new file mode 100755 index 000000000..ca0972ca8 --- /dev/null +++ b/egs/swbd/ASR/utils/fix_data_dir.sh @@ -0,0 +1,197 @@ +#!/bin/bash + +# This script makes sure that only the segments present in +# all of "feats.scp", "wav.scp" [if present], segments [if present] +# text, and utt2spk are present in any of them. +# It puts the original contents of data-dir into +# data-dir/.backup + +cmd="$@" + +utt_extra_files= +spk_extra_files= + +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: utils/data/fix_data_dir.sh " + echo "e.g.: utils/data/fix_data_dir.sh data/train" + echo "This script helps ensure that the various files in a data directory" + echo "are correctly sorted and filtered, for example removing utterances" + echo "that have no features (if feats.scp is present)" + exit 1 +fi + +data=$1 + +if [ -f $data/images.scp ]; then + image/fix_data_dir.sh $cmd + exit $? +fi + +mkdir -p $data/.backup + +[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; + +[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; + +set -e -o pipefail -u + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted { + file=$1 + sort -k1,1 -u <$file >$file.tmp + if ! 
cmp -s $file $file.tmp; then + echo "$0: file $1 is not in sorted order or not unique, sorting it" + mv $file.tmp $file + else + rm $file.tmp + fi +} + +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ + reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + check_sorted $data/$x + fi +done + + +function filter_file { + filter=$1 + file_to_filter=$2 + cp $file_to_filter ${file_to_filter}.tmp + utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter + if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then + length1=$(cat ${file_to_filter}.tmp | wc -l) + length2=$(cat ${file_to_filter} | wc -l) + if [ $length1 -ne $length2 ]; then + echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." + fi + fi + rm $file_to_filter.tmp +} + +function filter_recordings { + # We call this once before the stage when we filter on utterance-id, and once + # after. + + if [ -f $data/segments ]; then + # We have a segments file -> we need to filter this and the file wav.scp, and + # reco2file_and_utt, if it exists, to make sure they have the same list of + # recording-ids. + + if [ ! -f $data/wav.scp ]; then + echo "$0: $data/segments exists but not $data/wav.scp" + exit 1; + fi + awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings + n1=$(cat $tmpdir/recordings | wc -l) + [ ! -s $tmpdir/recordings ] && \ + echo "Empty list of recordings (bad file $data/segments)?" && exit 1; + utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp + mv $tmpdir/recordings.tmp $tmpdir/recordings + + + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + filter_file $tmpdir/recordings $data/segments + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + rm $data/segments.tmp + + filter_file $tmpdir/recordings $data/wav.scp + [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel + [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur + true + fi +} + +function filter_speakers { + # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... + utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + for s in cmvn.scp spk2gender; do + f=$data/$s + if [ -f $f ]; then + filter_file $f $tmpdir/speakers + fi + done + + filter_file $tmpdir/speakers $data/spk2utt + utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk + + for s in cmvn.scp spk2gender $spk_extra_files; do + f=$data/$s + if [ -f $f ]; then + filter_file $tmpdir/speakers $f + fi + done +} + +function filter_utts { + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + + ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order (fix this yourself)" && exit 1; + + ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; + + ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \ + echo "spk2utt is not in sorted order (fix this yourself)" && exit 1; + + if [ -f $data/utt2uniq ]; then + ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \ + echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1; + fi + + maybe_wav= + maybe_reco2dur= + [ ! 
-f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist. + [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts + for x in feats.scp text segments utt2lang $maybe_wav; do + if [ -f $data/$x ]; then + utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp + mv $tmpdir/utts.tmp $tmpdir/utts + fi + done + [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \ + rm $tmpdir/utts && exit 1; + + + if [ -f $data/utt2spk ]; then + new_nutts=$(cat $tmpdir/utts | wc -l) + old_nutts=$(cat $data/utt2spk | wc -l) + if [ $new_nutts -ne $old_nutts ]; then + echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts" + else + echo "fix_data_dir.sh: kept all $old_nutts utterances." + fi + fi + + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then + utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x + fi + fi + done + +} + +filter_recordings +filter_speakers +filter_utts +filter_speakers +filter_recordings + +utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + +echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/swbd/ASR/utils/parse_options.sh b/egs/swbd/ASR/utils/parse_options.sh new file mode 100755 index 000000000..34476fdb3 --- /dev/null +++ b/egs/swbd/ASR/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### No we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl new file mode 100755 index 000000000..23992f25d --- /dev/null +++ b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl @@ -0,0 +1,27 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +while(<>){ + @A = split(" ", $_); + @A > 1 || die "Invalid line in spk2utt file: $_"; + $s = shift @A; + foreach $u ( @A ) { + print "$u $s\n"; + } +} + + diff --git a/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl new file mode 100755 index 000000000..6e0e438ca --- /dev/null +++ b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. 
+# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# converts an utt2spk file to a spk2utt file. +# Takes input from the stdin or from a file argument; +# output goes to the standard out. + +if ( @ARGV > 1 ) { + die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; +} + +while(<>){ + @A = split(" ", $_); + @A == 2 || die "Invalid line in utt2spk file: $_"; + ($u,$s) = @A; + if(!$seen_spk{$s}) { + $seen_spk{$s} = 1; + push @spklist, $s; + } + push (@{$spk_hash{$s}}, "$u"); +} +foreach $s (@spklist) { + $l = join(' ',@{$spk_hash{$s}}); + print "$s $l\n"; +} From ce08230adea2b2c5c45fd3e028cbfd914ffd8ec2 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 7 Oct 2023 11:57:30 +0800 Subject: [PATCH 002/216] Update README.md (#1293) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 523203aa4..c89e7b9aa 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ We provide the following recipes: - [yesno][yesno] - [LibriSpeech][librispeech] - [GigaSpeech][gigaspeech] + - [AMI][ami] - [Aishell][aishell] - [Aishell2][aishell2] - [Aishell4][aishell4] @@ -37,6 +38,7 @@ We provide the following recipes: - [Aidatatang_200zh][aidatatang_200zh] - [WenetSpeech][wenetspeech] - [Alimeeting][alimeeting] + - [Switchboard][swbd] - [TAL_CSASR][tal_csasr] ### yesno @@ -393,4 +395,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [wenetspeech]: egs/wenetspeech/ASR [alimeeting]: egs/alimeeting/ASR [tal_csasr]: egs/tal_csasr/ASR +[ami]: egs/ami +[swbd]: egs/swbd/ASR [k2]: https://github.com/k2-fsa/k2 From fefffc02f68645dbbb2c0a54919c75f37da5dd4f Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 9 Oct 2023 17:39:23 +0800 Subject: [PATCH 003/216] Update optim.py (#1292) --- egs/librispeech/ASR/zipformer/optim.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c9b76526c..8ee2b0eb4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -491,6 +491,12 @@ class ScaledAdam(BatchedOptimizer): if self.show_dominant_parameters: assert p.shape[0] == len(param_names) self._show_gradient_dominating_parameter(tuples, tot_sumsq) + if ans != ans: # e.g. 
ans is nan + ans = 0.0 + if ans == 0.0: + for p, state, param_names in tuples: + p.grad.zero_() # get rid of infinity() + return ans def _show_gradient_dominating_parameter( @@ -573,7 +579,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad if clipping_scale != 1.0: - grad = grad * clipping_scale + grad *= clipping_scale step = state["step"] delta = state["delta"] From 9af144c26b91065a119d4e67c03004974462d24d Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 9 Oct 2023 23:15:22 +0800 Subject: [PATCH 004/216] Zipformer update result (#1296) * update Zipformer results --- README.md | 6 +++--- egs/librispeech/ASR/RESULTS.md | 34 +++++++++++++++++++++------------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index c89e7b9aa..da446109d 100644 --- a/README.md +++ b/README.md @@ -120,9 +120,9 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles | Encoder | Params | test-clean | test-other | |-----------------|--------|------------|------------| -| zipformer | 65.5M | 2.21 | 4.91 | -| zipformer-small | 23.2M | 2.46 | 5.83 | -| zipformer-large | 148.4M | 2.11 | 4.77 | +| zipformer | 65.5M | 2.21 | 4.79 | +| zipformer-small | 23.2M | 2.42 | 5.73 | +| zipformer-large | 148.4M | 2.06 | 4.63 | Note: No auxiliary losses are used in the training and no LMs are used in the decoding. diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index b945f43fd..fc7fcdc26 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -75,7 +75,7 @@ See for more details. ##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -90,18 +90,20 @@ You can use to deploy it. | greedy_search | 2.23 | 4.96 | --epoch 40 --avg 16 | | modified_beam_search | 2.21 | 4.91 | --epoch 40 --avg 16 | | fast_beam_search | 2.24 | 4.93 | --epoch 40 --avg 16 | +| greedy_search | 2.22 | 4.87 | --epoch 50 --avg 25 | +| modified_beam_search | 2.21 | 4.79 | --epoch 50 --avg 25 | +| fast_beam_search | 2.21 | 4.82 | --epoch 50 --avg 25 | | modified_beam_search_shallow_fusion | 2.01 | 4.37 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.3 | | modified_beam_search_LODR | 1.94 | 4.17 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.52 --LODR-scale -0.26 | | modified_beam_search_rescore | 2.04 | 4.39 | --epoch 40 --avg 16 --beam-size 12 | | modified_beam_search_rescore_LODR | 2.01 | 4.33 | --epoch 40 --avg 16 --beam-size 12 | - The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1,2,3" ./zipformer/train.py \ --world-size 4 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ @@ -115,8 +117,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ + --epoch 50 \ + --avg 25 \ --use-averaged-model 1 \ --exp-dir ./zipformer/exp \ --max-duration 600 \ @@ -129,7 +131,7 @@ To decode with external language models, please refer to the documentation [here ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -144,13 +146,16 @@ You can use to deploy it. 
| greedy_search | 2.49 | 5.91 | --epoch 40 --avg 13 | | modified_beam_search | 2.46 | 5.83 | --epoch 40 --avg 13 | | fast_beam_search | 2.46 | 5.87 | --epoch 40 --avg 13 | +| greedy_search | 2.46 | 5.86 | --epoch 50 --avg 23 | +| modified_beam_search | 2.42 | 5.73 | --epoch 50 --avg 23 | +| fast_beam_search | 2.46 | 5.78 | --epoch 50 --avg 23 | The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1" ./zipformer/train.py \ --world-size 2 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp-small \ @@ -169,8 +174,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 40 \ - --avg 13 \ + --epoch 50 \ + --avg 23 \ --exp-dir zipformer/exp-small \ --max-duration 600 \ --causal 0 \ @@ -185,7 +190,7 @@ done ##### large-scaled model, number of model parameters: 148439574, i.e., 148.4 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -200,13 +205,16 @@ You can use to deploy it. | greedy_search | 2.12 | 4.8 | --epoch 40 --avg 13 | | modified_beam_search | 2.11 | 4.7 | --epoch 40 --avg 13 | | fast_beam_search | 2.13 | 4.78 | --epoch 40 --avg 13 | +| greedy_search | 2.08 | 4.69 | --epoch 50 --avg 30 | +| modified_beam_search | 2.06 | 4.63 | --epoch 50 --avg 30 | +| fast_beam_search | 2.09 | 4.68 | --epoch 50 --avg 30 | The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1,2,3" ./zipformer/train.py \ --world-size 4 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp-large \ @@ -224,8 +232,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 40 \ - --avg 16 \ + --epoch 50 \ + --avg 30 \ --exp-dir zipformer/exp-large \ --max-duration 600 \ --causal 0 \ From 0d09a449303eb899eb729238436c3766a0a328b5 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 11 Oct 2023 10:06:00 +0800 Subject: [PATCH 005/216] Update train.py (#1299) --- egs/aishell/ASR/pruned_transducer_stateless7/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index 11671db92..9d9dd4288 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -703,7 +703,7 @@ def compute_loss( if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training From 103d617380c5a49599fcc7fd713c69d861989453 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 11 Oct 2023 11:04:20 +0800 Subject: [PATCH 006/216] bug fixes (#1301) --- egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py | 6 +++--- egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py index 1b6991bcd..2f8e658c5 100644 --- a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py +++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py @@ -32,7 +32,7 @@ from lhotse.dataset import ( # noqa F401 for 
PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples @@ -230,8 +230,8 @@ class LibriSpeechAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index 59d73c660..aeeb2ef78 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -32,7 +32,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -302,8 +302,8 @@ class SwitchBoardAsrDataModule: buffer_size=50000, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, From cb874e99055c3d62d199ce8d296bb118f3d8aa23 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 11 Oct 2023 12:20:12 +0800 Subject: [PATCH 007/216] add export-onnx.py for stateless8 (#1302) * add export-onnx.py for stateless8 * use tokens.txt to replace bpe.model --- .../export-onnx.py | 604 ++++++++++++++++++ 1 file changed, 604 insertions(+) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py new file mode 100755 index 000000000..3fef231a1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/WeijiZhuang/icefall-asr-librispeech-pruned-transducer-stateless8-2022-12-02 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/WeijiZhuang/icefall-asr-librispeech-pruned-transducer-stateless8-2022-12-02 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless8/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. 
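+
+3. (Optional) Sanity-check the exported encoder with onnxruntime. The snippet
+below is only an illustration and not part of this recipe; it assumes that the
+onnxruntime Python package is installed:
+
+    import numpy as np
+    import onnxruntime as ort
+
+    sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")  # from $repo/exp
+    x = np.zeros((1, 100, 80), dtype=np.float32)
+    x_lens = np.array([100], dtype=np.int64)
+    encoder_out, encoder_out_lens = sess.run(
+        ["encoder_out", "encoder_out_lens"], {"x": x, "x_lens": x_lens}
+    )
+    print(encoder_out.shape, encoder_out_lens)  # (1, T', joiner_dim), [T']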
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import num_tokens, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless7", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
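+
+    A minimal illustration of how the wrapped decoder is called (hypothetical
+    values; it assumes context_size == 2 and a blank ID of 0, which greedy
+    search feeds in as the initial context):
+
+        y = torch.zeros(1, 2, dtype=torch.int64)  # (N, context_size)
+        decoder_out = decoder_model(y)            # (N, joiner_dim)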
+ """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + # is defined in local/train_bpe_model.py + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if 
len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + 
model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() From 16a2748d6cc0eed7f08d034e835a4deef8565d73 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 11 Oct 2023 14:56:41 +0800 Subject: [PATCH 008/216] PromptASR for contextualized ASR with controllable style (#1250) * Add PromptASR with BERT as text encoder * Support using word-list based content prompts for context biasing * Upload the pretrained models to huggingface * Add usage example --- egs/libriheavy/ASR/RESULTS.md | 205 ++ egs/libriheavy/ASR/prepare_prompt_asr.sh | 36 + egs/libriheavy/ASR/shared | 1 + .../ASR/zipformer_prompt_asr/__init__.py | 0 .../zipformer_prompt_asr/asr_datamodule.py | 520 ++++ .../ASR/zipformer_prompt_asr/beam_search.py | 1 + .../ASR/zipformer_prompt_asr/dataset.py | 586 +++++ .../zipformer_prompt_asr/decode_baseline.py | 791 ++++++ .../ASR/zipformer_prompt_asr/decode_bert.py | 1025 ++++++++ ...decode_bert_with_style_save_decoding_mp.py | 963 +++++++ .../ASR/zipformer_prompt_asr/decoder.py | 130 + .../zipformer_prompt_asr/encoder_interface.py | 43 + .../zipformer_prompt_asr/export_PromptASR.py | 255 ++ .../ASR/zipformer_prompt_asr/joiner.py | 86 + .../ls_text_normalization.py | 153 ++ .../zipformer_prompt_asr/model_baseline.py | 262 ++ .../zipformer_prompt_asr/model_with_BERT.py | 392 +++ .../ASR/zipformer_prompt_asr/optim.py | 1168 +++++++++ .../ASR/zipformer_prompt_asr/pretrained.py | 359 +++ .../ASR/zipformer_prompt_asr/scaling.py | 1872 +++++++++++++ .../ASR/zipformer_prompt_asr/subsampling.py | 276 ++ .../ASR/zipformer_prompt_asr/test_model.py | 119 + .../text_normalization.py | 101 + .../zipformer_prompt_asr/train_baseline.py | 1390 ++++++++++ .../train_bert_encoder.py | 1798 +++++++++++++ .../zipformer_prompt_asr/transcribe_bert.py | 515 ++++ .../ASR/zipformer_prompt_asr/utils.py | 439 ++++ .../ASR/zipformer_prompt_asr/zipformer.py | 2310 +++++++++++++++++ icefall/utils.py | 32 +- 29 files changed, 15825 insertions(+), 3 deletions(-) create mode 100644 egs/libriheavy/ASR/RESULTS.md create mode 100755 egs/libriheavy/ASR/prepare_prompt_asr.sh create mode 120000 egs/libriheavy/ASR/shared create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py create mode 120000 egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py create mode 100755 egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py create mode 100755 egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py create mode 
100644 egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/optim.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py create mode 100755 egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py create mode 100755 egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/utils.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py diff --git a/egs/libriheavy/ASR/RESULTS.md b/egs/libriheavy/ASR/RESULTS.md new file mode 100644 index 000000000..4fbedad98 --- /dev/null +++ b/egs/libriheavy/ASR/RESULTS.md @@ -0,0 +1,205 @@ +## Results + +### Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) + +#### [zipformer_prompt_asr](./zipformer_prompt_asr) + +See for commit history and +our paper for more details. + + + +##### Training on the medium subset, with content & style prompt, **no** context list + +You can find a pre-trained model, training logs, decoding logs, and decoding results at: + +The training command is: + +```bash +causal=0 +subset=medium +memory_dropout_rate=0.05 +text_encoder_type=BERT + +python ./zipformer_prompt_asr/train_bert_encoder.py \ + --world-size 4 \ + --start-epoch 1 \ + --num-epochs 60 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --use-fp16 True \ + --memory-dropout-rate $memory_dropout_rate \ + --causal $causal \ + --subset $subset \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --text-encoder-type $text_encoder_type \ + --text-encoder-dim 768 \ + --use-context-list 0 \ + --top-k $top_k \ + --use-style-prompt 1 +``` + +The decoding results using utterance-level context (epoch-60-avg-10): + +| decoding method | lh-test-clean | lh-test-other | comment | +|----------------------|---------------|---------------|---------------------| +| modified_beam_search | 3.13 | 6.78 | --use-pre-text False --use-style-prompt False | +| modified_beam_search | 2.86 | 5.93 | --pre-text-transform upper-no-punc --style-text-transform upper-no-punc | +| modified_beam_search | 2.6 | 5.5 | --pre-text-transform mixed-punc --style-text-transform mixed-punc | + + +The decoding command is: + +```bash +for style in mixed-punc upper-no-punc; do + python ./zipformer_prompt_asr/decode_bert.py \ + --epoch 60 \ + --avg 10 \ + --use-averaged-model True \ + --post-normalization True \ + --causal False \ + --exp-dir ./zipformer_prompt_asr/exp \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --memory-layer 0 \ + --use-ls-test-set False \ + --use-ls-context-list False \ + --max-prompt-lens 1000 \ + --use-pre-text True \ + --use-style-prompt True \ + --style-text-transform $style \ + --pre-text-transform $style \ + --compute-CER 0 +done +``` + +##### Training on the medium subset, with content & style prompt, **with** 
context list + +You can find a pre-trained model, training logs, decoding logs, and decoding results at: + +This model is trained with an extra type of content prompt (context words), thus it does better +on **word-level** context biasing. Note that to train this model, please first run `prepare_prompt_asr.sh` +to prepare a manifest containing context words. + +The training command is: + +```bash + +causal=0 +subset=medium +memory_dropout_rate=0.05 +text_encoder_type=BERT +use_context_list=True + +# prepare the required data for context biasing +./prepare_prompt_asr.sh --stage 0 --stop_stage 1 + +python ./zipformer_prompt_asr/train_bert_encoder.py \ + --world-size 4 \ + --start-epoch 1 \ + --num-epochs 50 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --use-fp16 True \ + --memory-dropout-rate $memory_dropout_rate \ + --causal $causal \ + --subset $subset \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --text-encoder-type $text_encoder_type \ + --text-encoder-dim 768 \ + --use-context-list $use_context_list \ + --top-k 10000 \ + --use-style-prompt 1 +``` + +*Utterance-level biasing:* + +| decoding method | lh-test-clean | lh-test-other | comment | +|----------------------|---------------|---------------|---------------------| +| modified_beam_search | 3.17 | 6.72 | --use-pre-text 0 --use-style-prompt 0 | +| modified_beam_search | 2.91 | 6.24 | --pre-text-transform upper-no-punc --style-text-transform upper-no-punc | +| modified_beam_search | 2.72 | 5.72 | --pre-text-transform mixed-punc --style-text-transform mixed-punc | + + +The decoding command for the table above is: + +```bash +for style in mixed-punc upper-no-punc; do + python ./zipformer_prompt_asr/decode_bert.py \ + --epoch 50 \ + --avg 10 \ + --use-averaged-model True \ + --post-normalization True \ + --causal False \ + --exp-dir ./zipformer_prompt_asr/exp \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --memory-layer 0 \ + --use-ls-test-set False \ + --use-ls-context-list False \ + --max-prompt-lens 1000 \ + --use-pre-text True \ + --use-style-prompt True \ + --style-text-transform $style \ + --pre-text-transform $style \ + --compute-CER 0 +done +``` + +*Word-level biasing:* + +The results are reported on LibriSpeech test-sets using the biasing list provided from . +You need to set `--use-ls-test-set True` so that the LibriSpeech test sets are used. 
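+ +The word-level biasing lists come from the fbai-speech repository and are downloaded by stage 0 of `prepare_prompt_asr.sh`. As a minimal sketch of that stage (the destination directory `data/context_biasing` is simply the script's convention), the download amounts to: + +```bash +# Sketch of stage 0 of prepare_prompt_asr.sh: +# fetch the meta biasing lists for LibriSpeech +mkdir -p data/context_biasing +cd data/context_biasing +git clone https://github.com/facebookresearch/fbai-speech.git +cd ../.. +```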
+ +| decoding method | ls-test-clean | ls-test-other | comment | +|----------------------|---------------|---------------|---------------------| +| modified_beam_search | 2.4 | 5.08 | --use-pre-text 0 --use-style-prompt 0 | +| modified_beam_search | 2.14 | 4.62 | --use-ls-context-list 1 --pre-text-transform mixed-punc --style-text-transform mixed-punc --ls-distractors 0 | +| modified_beam_search | 2.14 | 4.64 | --use-ls-context-list 1 --pre-text-transform mixed-punc --style-text-transform mixed-punc --ls-distractors 100 | + +The decoding command is for the table above is: + +```bash +use_ls_test_set=1 +use_ls_context_list=1 + +for ls_distractors in 0 100; do + python ./zipformer_prompt_asr/decode_bert.py \ + --epoch 50 \ + --avg 10 \ + --use-averaged-model True \ + --post-normalization True \ + --causal False \ + --exp-dir ./zipformer_prompt_asr/exp \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --memory-layer 0 \ + --use-ls-test-set $use_ls_test_setse \ + --use-ls-context-list $use_ls_context_list \ + --ls-distractors $ls_distractors \ + --max-prompt-lens 1000 \ + --use-pre-text True \ + --use-style-prompt True \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc \ + --compute-CER 0 +done + +``` diff --git a/egs/libriheavy/ASR/prepare_prompt_asr.sh b/egs/libriheavy/ASR/prepare_prompt_asr.sh new file mode 100755 index 000000000..b931cea26 --- /dev/null +++ b/egs/libriheavy/ASR/prepare_prompt_asr.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +set -eou pipefail + +# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +stage=-1 +stop_stage=100 +manifest_dir=data/fbank +subset=medium +topk=10000 + +. shared/parse_options.sh || exit 1 + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download the meta biasing list for LibriSpeech" + mkdir -p data/context_biasing + cd data/context_biasing + git clone https://github.com/facebookresearch/fbai-speech.git + cd ../.. +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Add rare-words for context biasing to the manifest" + python zipformer_prompt_asr/utils.py \ + --manifest-dir $manifest_dir \ + --subset $subset \ + --top-k $topk + +fi diff --git a/egs/libriheavy/ASR/shared b/egs/libriheavy/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/libriheavy/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py b/egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py new file mode 100644 index 000000000..690003377 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py @@ -0,0 +1,520 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import torch +from dataset import PromptASRDataset +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # SingleCutSampler, + CutConcatenate, + CutMix, + DynamicBucketingSampler, + ExtraPadding, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriHeavyAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + if args.use_context_list: + assert args.rare_word_file is not None + with open(args.rare_word_file, "r") as f: + self.rare_word_list = ( + f.read().lower().split() + ) # Use lower-cased for easier style transform + else: + self.rare_word_list = None + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", + ) + + # Libriheavy specific arguments + group.add_argument( + "--subset", + type=str, + default="small", + help="Select the Libriheavy subset (small|medium|large)", + ) + + group.add_argument( + "--use-context-list", + type=str2bool, + default=False, + help="Use the context list of libri heavy", + ) + + group.add_argument( + "--top-k", + type=int, + default=10000, + help="""The top-k words are identified as common words, + the rest as rare words""", + ) + + group.add_argument( + "--with-decoding", + type=str2bool, + default=False, + help="If the texts field contain decoding", + ) + + group.add_argument( + "--random-left-padding", + type=str2bool, + ) + + group.add_argument( + "--rare-word-file", + type=str, + ) + + group.add_argument( + "--long-audio-cuts", + type=str, + default="data/manifest_npr/npr1_cuts_all_guids_0.jsonl.gz", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + text_sampling_func: Callable[[List[str]], str] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. 
+ sampler_state_dict: + The state dict for the training sampler. + """ + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = PromptASRDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = PromptASRDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + raise NotImplementedError( + "SingleCutSampler is no longer supported by lhotse" + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
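+ # Note: _SeedWorkers re-seeds every dataloader worker with (seed + worker_id), + # so on-the-fly transforms differ across workers while remaining reproducible + # for a fixed base random state.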
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + text_sampling_func: Callable[[List[str]], str] = None, + ) -> DataLoader: + transforms = [] + if self.args.random_left_padding: + logging.info("Enable random left padding") + transforms.append( + ExtraPadding(extra_frames=16, randomized=True, direction="left") + ) + + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = PromptASRDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, + ) + else: + validate = PromptASRDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get {self.args.subset} cuts") + + if self.args.use_context_list: + path = ( + self.args.manifest_dir + / f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz" + ) + elif self.args.with_decoding: + path = ( + self.args.manifest_dir + / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz" + ) + else: + path = ( + self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz" + ) + + logging.info(f"Loading manifest from {path}.") + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" + ) + return cuts_valid + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test-clean_official.jsonl.gz" + ) + return cuts_valid + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test-other_official.jsonl.gz" + ) + return cuts_valid + + @lru_cache() + def 
librispeech_test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def librispeech_test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def long_audio_cuts(self) -> CutSet: + logging.info("About to get long audio cuts") + cuts = load_manifest_lazy( + self.args.long_audio_cuts, + ) + return cuts + + @lru_cache() + def test_dev_cuts(self) -> CutSet: + logging.info("About to get test dev cuts") + cuts = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz" + ) + return cuts diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py b/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py new file mode 100644 index 000000000..e0bf8f73d --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py @@ -0,0 +1,586 @@ +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset import K2SpeechRecognitionDataset +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from text_normalization import ( + lower_all_char, + lower_only_alpha, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from torch.utils.data.dataloader import DataLoader, default_collate + + +class PromptASRDataset(torch.utils.data.Dataset): + """This is a dataset for Prompt ASR. It supports the following features: + 1. Select a tuple of (text, pre_text, style_text) randomly from a + list of texts as supervisions. + + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + text_sampling_func: Optional[Callable[[List[str]], str]] = None, + rare_word_list: Optional[List[str]] = None, + ): + """ + Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py + for more details. 
+ + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + :param text_sampling_func: Sampling a text as transcription from a list of texts. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # a text sampling function + self.text_sampling_func = text_sampling_func + self.rare_word_list = rare_word_list + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_frames and max_cuts. + """ + validate_for_asr(cuts) + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. 
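+ # supervision_intervals provides "sequence_idx", "start_frame" and + # "num_frames"; stacking them column-wise gives the (num_supervisions, 3) + # tensor that SpecAugment expects as ``supervision_segments``.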
+ segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "supervisions": default_collate( + [ + self.text_sampling_func( + texts=supervision.texts, + pre_texts=supervision.pre_texts, + context_list=supervision.context_list + if "context_list" in supervision.custom + else None, + rare_word_list=self.rare_word_list, + ) + if self.text_sampling_func is not None + else { + "text": train_text_normalization(supervision.texts[0]), + "pre_text": train_text_normalization(supervision.pre_texts[0]), + "style_text": train_text_normalization( + supervision.pre_texts[0] + ), + "transform_ids": 0, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + +def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str: + """A helper function that generates a random substring from a given string + + Args: + s (str): Input string + + Returns: + str: Returned substring + """ + min_len = min(len(s), min_len) + + start = random.randint(0, len(s) - min_len) + end = min(start + max_len, random.randint(start + min_len, len(s))) + + return s[start:end] + + +def triplet_text_sampling( + texts: List[str], + pre_texts: List[str], + context_list: Optional[str] = None, + rare_word_list: Optional[List[str]] = None, + transforms: Optional[List[Callable[[str], str]]] = None, + min_len_style: Optional[int] = 80, +) -> Dict[str, str]: + """This function generates a triplet of + (pre_text, style_text, ref_text). The style of style_text and ref_text + should **always** match, whereas the style of pre_text is arbitrary. + Suppose we have 2 different transforms A,B, and the preceding text is + referred to as pre_text. The following three tuples are all valid: + + (A(pre_text), A(style_text), A(ref_text)) + (A(pre_text), B(style_text), B(ref_text)) + (A(pre_text), A(style_text), A(ref_text)) + (B(pre_text), B(style_text), B(ref_text)) + + If transforms is not given, the following pre-defined transforms + are available: + 0: original (mixed-cased, with punc) + 1: upper_only_alpha (upper-cased, no punc) + + When the transform of text and pre_text match, we can use the whole + pre_text as the prompt text. + + Args: + texts (List[str]): + A list of ref_texts whose first item is the ground truth + text from books. + pre_texts (List[str]): + A list of pre_texts, whose first item is the groundtruth + pre_text from books. 
+ context_list: Optional[str] = None, + A list of biasing words separated by space + rare_word_list: Optional[str] = None, + A list of rare-words separated by space (used as distractors) + transforms (List[Callable[[str], str]]): A list of possible transforms to be applied + + Returns: + A dictionary of ref_text, pre_text, style_text + """ + assert len(texts) == len(pre_texts) + assert len(texts) == 2 + + # we assume the first item to be ground truth + gt_text = texts[0] + gt_pre_text = pre_texts[0] + + if transforms is None: + transforms = [ + lambda x: x, # return it self + upper_only_alpha, + lower_only_alpha, + lower_all_char, + ] + + sampling_weight = [ + 0.7, + 0.3, + 0.0, + 0.0, + ] # Mixed-punc should have the largest sampling prob + + total_transforms = len(transforms) # do not use the recognized trans + + # Randomly sample transforms + i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight) + + # get the normalized text and pre_text + text = transforms[i_text](gt_text) + pre_text = transforms[i_pre_text](gt_pre_text) + + if i_text == i_pre_text: + style_text = get_substring(pre_text, min_len=min_len_style, max_len=150) + else: + # get the pre_text of same style as text + # For now, **don't** do transform to the style text, because we do it after the dataloader + style_text = gt_pre_text + # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text) + style_text = get_substring(style_text, min_len=min_len_style, max_len=150) + + return { + "text": train_text_normalization(text), + "pre_text": train_text_normalization(pre_text), + "style_text": train_text_normalization(style_text), + "transform_ids": i_text, + } + + +def triplet_text_sampling_with_context_list( + texts: List[str], + pre_texts: List[str], + context_list: str, + rare_word_list: List[str], + transforms: Optional[List[Callable[[str], str]]] = None, + min_len_style: Optional[int] = 80, +) -> Dict[str, str]: + """This function generates a triplet of + (pre_text, style_text, ref_text). The pre_text is either the preceding text + or a list of words (context words + distractors). + The style of style_text and ref_text should **always** match, whereas + the style of pre_text is arbitrary. + Suppose we have 2 different transforms A,B, and the preceding text is + referred to as pre_text. The following three tuples are all valid: + + (A(pre_text), A(style_text), A(ref_text)) + (A(pre_text), B(style_text), B(ref_text)) + (A(pre_text), A(style_text), A(ref_text)) + (B(pre_text), B(style_text), B(ref_text)) + + If transforms is not given, the following pre-defined transforms + are available: + 0: original (mixed-cased, with punc) + 1: upper_only_alpha (upper-cased, no punc) + + When the transform of text and pre_text match, we can use the whole + pre_text as the prompt text. + + Args: + texts (List[str]): + A list of ref_texts whose first item is the ground truth + text from books. + pre_texts (List[str]): + A list of pre_texts, whose first item is the groundtruth + pre_text from books. 
+ context_list: Optional[str] = None, + A list of biasing words separated by space + rare_word_list: Optional[str] = None, + A list of rare-words separated by space (used as distractors) + transforms (List[Callable[[str], str]]): A list of possible transforms to be applied + + Returns: + A dictionary of ref_text, pre_text, style_text + Returns: + str: A dictionary + """ + # import pdb; pdb.set_trace() + assert len(texts) == len(pre_texts) + assert len(texts) == 2 + + if context_list is not None: + context_list = context_list.lower() + + # we assume the first item to be ground truth + gt_text = texts[0] + gt_pre_text = pre_texts[0] + + if transforms is None: + transforms = [ + lambda x: x, # return it self + upper_only_alpha, + lower_only_alpha, + lower_all_char, + ] + + sampling_weight = [ + 0.7, + 0.3, + 0.0, + 0.0, + ] # Mixed-punc should have the largest sampling prob + + total_transforms = len(transforms) # do not use the recognized trans + + # Select a transformation randomly + i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight) + + # get the normalized text and pre_text + text = transforms[i_text](gt_text) + pre_text = get_pre_text_with_context_list2( + text=gt_text, + pre_text=gt_pre_text, + context_list=context_list, + rare_words_list=rare_word_list, + ) + pre_text = transforms[i_pre_text](pre_text) + + if i_text == i_pre_text: + style_text = get_substring(pre_text, min_len=min_len_style, max_len=150) + else: + # get the pre_text of same style as text + # For now, **don't** do transform to the style text + style_text = gt_pre_text + # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text) + style_text = get_substring(style_text, min_len=min_len_style, max_len=150) + + return { + "text": train_text_normalization(text), + "pre_text": train_text_normalization(pre_text), + "style_text": train_text_normalization(style_text), + "transform_ids": i_text, + } + + +def get_pre_text_with_context_list( + text: str, + pre_text: str, + context_list: str, + rare_words_list: List[str] = None, +) -> str: + # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha + # By a small proportion of time, use the substring of ref_text as pre_text + + if context_list != "" and context_list is not None: + v = random.random() + if v < 0.5: + # correct + distractors + # sample distractors + num_distractors = random.randint(0, 50) + distractors = random.sample(rare_words_list, num_distractors) + # sample correct + correct = context_list.split() + i = random.randint(1, len(correct)) + correct = random.sample(correct, i) + # combine correct and distractors + pre_text = distractors + correct + random.shuffle(pre_text) + pre_text = " ".join(pre_text) + elif v < 0.7: + splitted = text.split() + sampling_weights = [len(w) ** 1.2 for w in splitted] + sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] + i = random.randint(1, min(len(splitted), 20)) + splitted = list(np.random.choice(splitted, i, p=sampling_weights)) + num_distractors = random.randint(0, 70) + distractors = random.sample(rare_words_list, num_distractors) + splitted += distractors + random.shuffle(splitted) # shuffle the list + pre_text = " ".join(splitted) + else: + pre_text = pre_text + else: + v = random.random() + if v < 0.1: + splitted = text.split() + sampling_weights = [len(w) ** 1.2 for w in splitted] + sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] + i = random.randint(1, min(len(splitted), 20)) + splitted = 
list(np.random.choice(splitted, i, p=sampling_weights)) + pre_text = " ".join(splitted) + num_distractors = random.randint(0, 70) + distractors = random.sample(rare_words_list, num_distractors) + splitted += distractors + random.shuffle(splitted) # shuffle the list + elif v < 0.2: + # full distractors + num_distractors = random.randint(5, 100) + distractors = random.sample(rare_words_list, num_distractors) + pre_text = " ".join(distractors) + + elif v < 0.3: + pre_text = get_substring(text, min_len=15, max_len=150) + else: + pre_text = pre_text + + return pre_text + + +def get_pre_text_with_context_list2( + text: str, + pre_text: str, + context_list: str, + rare_words_list: List[str] = None, +) -> str: + # Get the pre_text, either the ground truth preceding text or + # a list of words consisting of biasing words and distrators + # By a small proportion of time, use the substring of ref_text as pre_text + + if context_list != "" and context_list is not None: + v = random.random() + if v < 0.4: + # sample distractors + num_distractors = random.randint(50, 100) + distractors = random.sample(rare_words_list, num_distractors) + # sample correct + correct = context_list.split() + i = random.randint(1, len(correct)) + correct = random.sample(correct, i) + # combine correct and distractors + pre_text = distractors + correct + random.shuffle(pre_text) + pre_text = " ".join(pre_text) + elif v < 0.55: + splitted = text.split() + sampling_weights = [ + len(w) ** 1.2 for w in splitted + ] # longer words with higher weights + sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] + i = random.randint(1, min(len(splitted), 20)) + splitted = list(np.random.choice(splitted, i, p=sampling_weights)) + num_distractors = random.randint(50, 100) + distractors = random.sample(rare_words_list, num_distractors) + splitted += distractors + random.shuffle(splitted) # shuffle the list + pre_text = " ".join(splitted) + else: + pre_text = pre_text + else: + v = random.random() + if v < 0.3: + splitted = text.split() + sampling_weights = [len(w) ** 1.2 for w in splitted] + sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] + i = random.randint(1, min(len(splitted), 20)) + splitted = list(np.random.choice(splitted, i, p=sampling_weights)) + pre_text = " ".join(splitted) + num_distractors = random.randint(50, 100) + distractors = random.sample(rare_words_list, num_distractors) + splitted += distractors + random.shuffle(splitted) # shuffle the list + elif v < 0.4: + # full distractors + num_distractors = random.randint(5, 100) + distractors = random.sample(rare_words_list, num_distractors) + pre_text = " ".join(distractors) + elif v < 0.6: + pre_text = get_substring(text, min_len=15, max_len=150) + else: + pre_text = pre_text + + return pre_text + + +def naive_triplet_text_sampling( + texts: List[str], + pre_texts: List[str], + context_list: str = None, + rare_word_list: List[str] = None, + min_len_style: Optional[int] = 120, +): + # The most simplest text sampling function, used only for + # evaluation, use a fixed sentence as the style text + + return { + "text": train_text_normalization(texts[0]), + "pre_text": train_text_normalization(pre_texts[0]), + "style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. 
What do you think?", + "transform_ids": 0, + } + + +def random_shuffle_subset( + data: List[str], + p: float = 0.2, + p_mask: float = 0.05, +) -> List[str]: + """ + Randomly shuffle the subset by probability `p`, which means that p% of the samples + in the original batch are shuffled, the others are kept in the original order. + + With a probability of `p_mask`, replace the original string with an empty string. + + """ + + num_to_shuffle = int(len(data) * p) + id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False) + item_to_shuffle = [data[id] for id in id_to_shuffle] + random.shuffle(item_to_shuffle) + + for id, item in zip(id_to_shuffle, item_to_shuffle): + data[id] = item + + # Randomly mask a proportion of the data to empty string + if p_mask > 0: + for i in range(len(data)): + if random.random() < p_mask: + data[i] = "" + + return data + + +if __name__ == "__main__": + texts = [ + "AA, BB, cC, dD!", + "AA BB CC DD", + ] + + pre_texts = [ + "EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?", + "EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG", + ] + for i in range(10): + print(f"Run: {i}") + print(triplet_text_sampling(texts, pre_texts)) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py new file mode 100644 index 000000000..6a3bab3c8 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py @@ -0,0 +1,791 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import greedy_search, greedy_search_batch, modified_beam_search +from ls_text_normalization import word_normalization +from text_normalization import ( + ref_text_normalization, + remove_non_alphabetic, + upper_only_alpha, +) +from train_baseline import add_model_arguments, get_params, get_transducer_model +from utils import write_error_stats + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. 
+ """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=True, + help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", + ) + + parser.add_argument( + "--long-audio-recog", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--use-ls-test-set", + type=str2bool, + default=False, + help="Use librispeech test set for evaluation.", + ) + + parser.add_argument( + "--compute-CER", + type=str2bool, + default=True, + help="Reports CER. By default, only reports WER", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. 
See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + texts = batch["supervisions"]["text"] + batch_size = feature.size(0) + + # Get the transducer encoder output + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=feature, + feature_lens=feature_lens, + ) + + hyps = [] + + if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
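+ # Greedy search is fast, so progress is logged less frequently than for + # the slower beam-search based decoding methods.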
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + if not params.use_ls_test_set: + book_names = [ + cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"] + ] + else: + book_names = ["" for _ in cut_ids] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, book_name, hyp_words, ref_text in zip( + cut_ids, book_names, hyps, texts + ): + ref_text = ref_text_normalization(ref_text) + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + # if not params.use_ls_test_set: + # results[name + " " + book_name].extend(this_batch) + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + biasing_words: List[str] = None, +): + test_set_wers = dict() + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                biasing_words=biasing_words,
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+        if params.compute_CER:
+            # Write CER statistics
+            recog_path = (
+                params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            )
+            store_transcripts(filename=recog_path, texts=results, char_level=True)
+            errs_filename = (
+                params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
+            )
+            with open(errs_filename, "w") as f:
+                cer = write_error_stats(
+                    f,
+                    f"{test_set_name}-{key}",
+                    results,
+                    enable_log=True,
+                    compute_CER=params.compute_CER,
+                )
+                test_set_cers[key] = cer
+
+            logging.info("Wrote detailed CER stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+    if params.compute_CER:
+        test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
+        errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+        with open(errs_info, "w") as f:
+            print("settings\tCER", file=f)
+            for key, val in test_set_cers:
+                print("{}\t{}".format(key, val), file=f)
+
+        s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
+        note = "\tbest for {}".format(test_set_name)
+        for key, val in test_set_cers:
+            s += "{} CER\t{}{}\n".format(key, val, note)
+            note = ""
+        logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriHeavyAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "modified_beam_search",
+    )
+
+    if params.long_audio_recog:
+        params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
+    else:
+        params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if "ngram" in params.decoding_method:
+        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    
logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + LM = None + + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + test_clean_cuts = libriheavy.test_clean_cuts() + test_other_cuts = libriheavy.test_other_cuts() + ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() + ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() + long_audio_cuts = libriheavy.long_audio_cuts() + + test_clean_dl = libriheavy.valid_dataloaders( + test_clean_cuts, + ) + test_other_dl = libriheavy.valid_dataloaders( + test_other_cuts, + ) + ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) + ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) + long_audio_dl = libriheavy.valid_dataloaders( + long_audio_cuts, + ) + + if params.use_ls_test_set: + test_sets = ["ls-test-clean", "ls-test-other"] + test_dl = [ls_test_clean_dl, ls_test_other_dl] + else: + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + if params.long_audio_recog: + test_sets = ["long-audio"] + test_dl = [long_audio_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + if params.use_ls_test_set: + f = open( + "data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r" + ) + biasing_words = f.read().strip().split() + f.close() + else: + biasing_words = None + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if params.post_normalization: + if "-post-normalization" not in params.suffix: + params.suffix += "-post-normalization" + + new_res = {} + for k in results_dict: + new_ans = [] + for item in results_dict[k]: + id, ref, hyp = item + if params.use_ls_test_set: + hyp = ( + " ".join(hyp).replace("-", " ").split() + ) # handle the hypens + hyp = upper_only_alpha(" ".join(hyp)).split() + hyp = [word_normalization(w.upper()) for w in hyp] + hyp = " ".join(hyp).split() + hyp = [w for w in hyp if w != ""] + ref = upper_only_alpha(" ".join(ref)).split() + else: + hyp = upper_only_alpha(" ".join(hyp)).split() + ref = upper_only_alpha(" ".join(ref)).split() + new_ans.append((id, ref, hyp)) + new_res[k] = new_ans + + save_results( + params=params, + test_set_name=test_set, + results_dict=new_res, + biasing_words=biasing_words, + ) + + if params.suffix.endswith("-post-normalization"): + params.suffix = params.suffix.replace("-post-normalization", "") + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py new file mode 100755 index 000000000..e71999b0a --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py @@ -0,0 +1,1025 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method greedy_search \ + --text-encoder-type BERT \ + --memory-layer 0 \ + --use-pre-text True \ + --use-style-prompt True \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc \ + --compute-CER 0 + + +(2) modified beam search +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --text-encoder-type BERT \ + --memory-layer 0 \ + --use-pre-text True \ + --use-style-prompt True \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc \ + --compute-CER 0 + +(3) Decode LibriSpeech + +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --use-ls-test-set True \ + --beam-size 4 \ + --text-encoder-type BERT \ + --memory-layer 0 \ + --use-pre-text True \ + --use-style-prompt True \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc \ + --compute-CER 0 + +(4) Decode LibriSpeech + biasing list + +biasing_list=100 # could also be 0 + +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --use-ls-test-set True \ + --use-ls-context-list True \ + --biasing-level utterance \ + --ls-distractors $biasing_list \ + --post-normalization True \ + --text-encoder-type BERT \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc + + +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import greedy_search, greedy_search_batch, modified_beam_search +from dataset import naive_triplet_text_sampling, random_shuffle_subset +from ls_text_normalization import word_normalization +from text_normalization import ( + _apply_style_transform, + lower_all_char, + lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from train_bert_encoder import ( + _encode_texts_as_bytes_with_tokenizer, + add_model_arguments, + get_params, + get_tokenizer, + get_transducer_model, +) +from transformers import BertModel, BertTokenizer +from utils import brian_biasing_list, get_facebook_biasing_list, write_error_stats + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the 
checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-pre-text", + type=str2bool, + default=True, + help="Use pre-text is available during decoding", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Use style prompt when evaluation", + ) + + parser.add_argument( + "--max-prompt-lens", + type=int, + default=1000, + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=True, + help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", + ) + + parser.add_argument( + "--compute-CER", + type=str2bool, + default=False, + help="Reports CER. By default, only reports WER", + ) + + parser.add_argument( + "--style-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of style prompt, i.e style_text", + ) + + parser.add_argument( + "--pre-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of content prompt, i.e pre_text", + ) + + parser.add_argument( + "--use-ls-test-set", + type=str2bool, + default=False, + help="Use librispeech test set for evaluation.", + ) + + parser.add_argument( + "--use-ls-context-list", + type=str2bool, + default=False, + help="If use a fixed context list for LibriSpeech decoding", + ) + + parser.add_argument( + "--biasing-level", + type=str, + default="utterance", + choices=["utterance", "Book", "Chapter"], + ) + + parser.add_argument( + "--ls-distractors", + type=int, + default=0, + help="The number of distractors into context list for LibriSpeech decoding", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + tokenizer: spm.SentencePieceProcessor, + batch: dict, + biasing_dict: dict = None, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + tokenizer: + Tokenizer for the text encoder + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + biasing_dict: + A dictionary in the form `{cut_id: :w1 w2"}` that contains a list + of biasing words (separated with space) + word_table: + The word symbol table. + decoding_graph: + The decoding graph. 
Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + cuts = batch["supervisions"]["cut"] + cut_ids = [c.supervisions[0].id for c in cuts] + batch_size = feature.size(0) + + if "pre_text" in batch["supervisions"] and params.use_pre_text: + pre_texts = batch["supervisions"]["pre_text"] + pre_texts = [train_text_normalization(t) for t in pre_texts] + else: + pre_texts = ["" for _ in range(batch_size)] + + # get the librispeech biasing data + if params.use_pre_text and (params.use_ls_context_list and params.use_ls_test_set): + if params.biasing_level == "utterance": + pre_texts = [biasing_dict[id] for id in cut_ids] + elif params.biasing_level == "Chapter": + chapter_ids = [c.split("-")[1] for c in cut_ids] + pre_texts = [biasing_dict[id] for id in chapter_ids] + elif params.biasing_level == "Book": + chapter_ids = [c.split("-")[1] for c in cut_ids] + pre_texts = [biasing_dict[id] for id in chapter_ids] + else: + raise ValueError(f"Unseen biasing level: {params.biasing_level}") + if params.pre_text_transform == "mixed-punc": + pre_texts = [t.lower() for t in pre_texts] + + # get style_text + if params.use_style_prompt: + fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related." + style_texts = batch["supervisions"].get( + "style_text", [fixed_sentence for _ in range(batch_size)] + ) + style_texts = [train_text_normalization(t) for t in style_texts] + else: + style_texts = ["" for _ in range(batch_size)] # use empty string + + # Get the text embedding + if params.use_pre_text or params.use_style_prompt: + # apply style transform to the pre_text and style_text + pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) + if not params.use_ls_context_list: + pre_texts = [t[-params.max_prompt_lens :] for t in pre_texts] + + if params.use_style_prompt: + style_texts = _apply_style_transform( + style_texts, params.style_text_transform + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + # Use tokenizer to prepare input for text encoder + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_texts, + style_texts=style_texts, + tokenizer=tokenizer, + device=device, + no_limit=True, + ) + logging.info( + f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}" + ) + + memory, memory_key_padding_mask = model.encode_text( + encoded_inputs=encoded_inputs, + style_lens=style_lens, + ) # (T,B,C) + else: + memory = None + memory_key_padding_mask = None + + # Get the transducer encoder output + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=feature, + feature_lens=feature_lens, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + hyps = [] + + if params.decoding_method == 
"greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + tokenizer: spm.SentencePieceProcessor, + biasing_dict: Dict = None, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + tokenizer: + Tokenizer for the text encoder + biasing_dict: + A dictionary in the form `{cut_id: :w1 w2"}` that contains a list + of biasing words (separated with space) + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"][ + "text" + ] # By default, this should be in mixed-punc format + + # the style of ref_text should match style_text + texts = _apply_style_transform(texts, params.style_text_transform) + if params.use_style_prompt: + texts = _apply_style_transform(texts, params.style_text_transform) + + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + if not params.use_ls_test_set: + try: + book_names = [ + cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"] + ] + except AttributeError: + book_names = [ + cut.id.split("/")[0] for cut in batch["supervisions"]["cut"] + ] + else: + book_names = ["" for _ in cut_ids] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + biasing_dict=biasing_dict, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, book_name, hyp_words, ref_text in zip( + cut_ids, book_names, hyps, texts + ): + ref_text = ref_text_normalization( + ref_text + ) # remove full-width symbols & some book marks + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + biasing_words: List[str] = None, +): + test_set_wers = dict() + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + if params.compute_CER: + # Write CER statistics + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results, char_level=True) + errs_filename = ( + params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + cer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=params.compute_CER, + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed CER stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + if params.compute_CER: + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{} CER\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "modified_beam_search", + ) + + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
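+        # streaming (causal) decoding runs with a single chunk size and
+        # left-context length; both are recorded in the result-file suffix below.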
+        params.suffix += f"-chunk-{params.chunk_size}"
+        params.suffix += f"-left-context-{params.left_context_frames}"
+
+    if "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_pre_text:
+        params.suffix += (
+            f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
+        )
+
+    if params.use_style_prompt:
+        params.suffix += f"-style-prompt-{params.style_text_transform}"
+
+    if params.use_ls_context_list:
+        assert (
+            params.use_pre_text
+        ), "Must set --use-pre-text to True if using context list"
+        params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
+        if params.biasing_level == "utterance" and params.ls_distractors:
+            params.suffix += f"-ls-context-distractors-{params.ls_distractors}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+    tokenizer = get_tokenizer(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + LM = None + + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + test_clean_cuts = libriheavy.test_clean_cuts() + test_other_cuts = libriheavy.test_other_cuts() + ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() + ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() + + test_clean_dl = libriheavy.valid_dataloaders( + test_clean_cuts, text_sampling_func=naive_triplet_text_sampling + ) + test_other_dl = libriheavy.valid_dataloaders( + test_other_cuts, text_sampling_func=naive_triplet_text_sampling + ) + ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) + ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) + + if params.use_ls_test_set: + test_sets = ["ls-test-clean", "ls-test-other"] + test_dl = [ls_test_clean_dl, ls_test_other_dl] + else: + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + biasing_dict = None + if params.use_ls_context_list: + if test_set == "ls-test-clean": + biasing_dict = get_facebook_biasing_list( + test_set="test-clean", + num_distractors=params.ls_distractors, + ) + elif test_set == "ls-test-other": + biasing_dict = get_facebook_biasing_list( + test_set="test-other", + num_distractors=params.ls_distractors, + ) + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + biasing_dict=biasing_dict, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if params.post_normalization: + if "-post-normalization" not in params.suffix: + params.suffix += "-post-normalization" + + new_res = {} + for k in results_dict: + new_ans = [] + for item in results_dict[k]: + id, ref, hyp = item + if params.use_ls_test_set: + hyp = ( + " ".join(hyp).replace("-", " ").split() + ) # handle the hypens + hyp = upper_only_alpha(" ".join(hyp)).split() + hyp = [word_normalization(w.upper()) for w in hyp] + hyp = " ".join(hyp).split() + hyp = [w for w in hyp if w != ""] + ref = upper_only_alpha(" ".join(ref)).split() + else: + hyp = upper_only_alpha(" ".join(hyp)).split() + ref = upper_only_alpha(" ".join(ref)).split() + new_ans.append((id, ref, hyp)) + new_res[k] = new_ans + + save_results( + params=params, + test_set_name=test_set, + results_dict=new_res, + ) + + if params.suffix.endswith("-post-normalization"): + params.suffix = params.suffix.replace("-post-normalization", "") + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py new file mode 100755 index 000000000..4559ebb6d --- /dev/null +++ 
b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py @@ -0,0 +1,963 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import ( + greedy_search, + greedy_search_batch, + greedy_search_batch_with_context, + greedy_search_with_context, + modified_beam_search, +) +from dataset import naive_triplet_text_sampling, random_shuffle_subset +from lhotse import load_manifest_lazy +from text_normalization import ( + lower_all_char, + lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from train_bert_encoder_with_style import ( + _encode_texts_as_bytes_with_tokenizer, + add_model_arguments, + get_params, + get_tokenizer, + get_transducer_model, +) +from transformers import BertModel, BertTokenizer +from utils import get_facebook_biasing_list + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--log-dir", + type=str, + required=True, + help="Where to store the logs", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--input-manifest", + type=str, + required=True, + help="The input manifest to be decoded", + ) + + parser.add_argument( + "--output-manifest", + type=str, + required=True, + help="Where to store the output manifest (directory)", + ) + + parser.add_argument( + "--use-pre-text", + type=str2bool, + default=True, + help="Use pre-text is available during decoding", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Use style prompt when evaluation", + ) + + parser.add_argument( + "--use-context-embedding", + type=str2bool, + default=False, + help="Use context fuser when evaluation", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=True, + help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", + ) + + parser.add_argument( + "--compute-CER", + type=str2bool, + default=True, + help="Reports CER. By default, only reports WER", + ) + + parser.add_argument( + "--style-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of style prompt, i.e style_text", + ) + + parser.add_argument( + "--pre-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of content prompt, i.e pre_text", + ) + + parser.add_argument( + "--use-ls-test-set", + type=str2bool, + default=False, + help="Use librispeech test set for evaluation.", + ) + + parser.add_argument( + "--use-ls-context-list", + type=str2bool, + default=False, + help="If use a fixed context list for LibriSpeech decoding", + ) + + add_model_arguments(parser) + + return parser + + +def _apply_style_transform(text: List[str], transform: str) -> List[str]: + """Apply transform to a list of text. By default, the text are in + ground truth format, i.e mixed-punc. + + Args: + text (List[str]): Input text string + transform (str): Transform to be applied + + Returns: + List[str]: _description_ + """ + if transform == "mixed-punc": + return text + elif transform == "upper-no-punc": + return [upper_only_alpha(s) for s in text] + elif transform == "lower-no-punc": + return [lower_only_alpha(s) for s in text] + elif transform == "lower-punc": + return [lower_all_char(s) for s in text] + else: + raise NotImplementedError(f"Unseen transform: {transform}") + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + tokenizer, + batch: dict, + biasing_dict: dict = None, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. 
See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + cuts = batch["supervisions"]["cut"] + cut_ids = [c.supervisions[0].id for c in cuts] + batch_size = feature.size(0) + + # get pre_text + if "pre_text" in batch["supervisions"] and params.use_pre_text: + pre_texts = batch["supervisions"][ + "text" + ] # use the ground truth ref text as pre_text + pre_texts = [train_text_normalization(t) for t in pre_texts] + else: + pre_texts = ["" for _ in range(batch_size)] + + if params.use_ls_context_list: + pre_texts = [biasing_dict[id] for id in cut_ids] + + # get style_text + if params.use_style_prompt: + fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related." + style_texts = batch["supervisions"].get( + "style_text", [fixed_sentence for _ in range(batch_size)] + ) + style_texts = [train_text_normalization(t) for t in style_texts] + else: + style_texts = ["" for _ in range(batch_size)] # use empty string + + # Get the text embedding input + if params.use_pre_text or params.use_style_prompt: + + # apply style transform to the pre_text and style_text + pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) + # pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0) + if params.use_style_prompt: + style_texts = _apply_style_transform( + style_texts, params.style_text_transform + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + # Use tokenizer to prepare input for text encoder + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_texts, + style_texts=style_texts, + tokenizer=tokenizer, + device=device, + ) + + memory, memory_key_padding_mask = model.encode_text( + encoded_inputs=encoded_inputs, + style_lens=style_lens, + ) # (T,B,C) + else: + memory = None + memory_key_padding_mask = None + + # Get the transducer encoder output + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=feature, + feature_lens=feature_lens, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + hyps = [] + + if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + if memory is None or not params.use_context_embedding: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + else: + memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C) + context = model.context_fuser( + memory, padding_mask=memory_key_padding_mask + ) # (N,C) + context = model.joiner.context_proj(context) # (N,C) + hyp_tokens = greedy_search_batch_with_context( + model=model, + encoder_out=encoder_out, + 
encoder_out_lens=encoder_out_lens, + context=context, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + if memory is None or not params.use_context_embedding: + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + else: + cur_context = context[i : i + 1, :] + hyp = greedy_search_with_context( + model=model, + encoder_out=encoder_out_i, + context=cur_context, + max_sym_per_frame=params.max_sym_per_frame, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + tokenizer, + biasing_dict: Dict = None, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 40 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"][ + "text" + ] # By default, this should be in mixed-punc format + + # the style of ref_text should match style_text + texts = _apply_style_transform(texts, params.style_text_transform) + if params.use_style_prompt: + texts = _apply_style_transform(texts, params.style_text_transform) + + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + biasing_dict=biasing_dict, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_text = ref_text_normalization( + ref_text + ) # remove full-width symbols & some book marks + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + if params.compute_CER: + # Write CER statistics + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results, char_level=True) + errs_filename = ( + params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + cer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=params.compute_CER, + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed CER stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + if params.compute_CER: + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{} CER\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +def add_decoding_result_to_manifest( + in_manifest, + out_manifest: str, + results_dict: Dict, +): + # write the decoding results with prompt to the manifest as an + # extra ref text + new_ans = {} + for key, value in results_dict.items(): + for items in value: + id, ref, hyp = items + new_ans[id] = " ".join(hyp) + + def _add_decoding(c): + key = c.supervisions[0].id + c.supervisions[0].texts.append(new_ans[key]) + return c + + in_manifest = in_manifest.map(_add_decoding) + logging.info(f"Saving manifest to {out_manifest}") + in_manifest.to_file(out_manifest) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + cuts = load_manifest_lazy(args.input_manifest) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + splitted_cuts = cuts.split(num_splits=world_size) + mp.spawn( + run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True + ) + else: + run(rank=0, world_size=1, args=args, cuts=cuts) + + +@torch.no_grad() +def run(rank, world_size, args, cuts): + params = get_params() + params.update(vars(args)) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_pre_text: + params.suffix += f"-pre-text-{params.pre_text_transform}" + + if params.use_style_prompt: + params.suffix += f"-style-prompt-{params.style_text_transform}" + + params.suffix += f"-{rank}" + + 
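+ # The rank is appended so that files named with params.suffix stay unique
+ # when several decoding workers run in parallel.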
world_size = params.world_size + + params.output_manifest = Path(params.output_manifest) + if world_size > 1: + cuts = cuts[rank] + out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz" + else: + out_name = params.output_manifest / "with_decoding.jsonl.gz" + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}") + logging.info("Decoding started") + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> and <unk> are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.unk_id = sp.piece_to_id("<unk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + tokenizer = get_tokenizer(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + LM = None + + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results.
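+ # return_cuts=True makes the dataloader attach the original lhotse cuts to
+ # each batch, which is how decode_dataset() recovers the cut ids.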
+ args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + dl = libriheavy.valid_dataloaders( + cuts, text_sampling_func=naive_triplet_text_sampling + ) + + test_sets = ["test"] + test_dl = [dl] + + for test_set, test_dl in zip(test_sets, test_dl): + biasing_dict = None + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + biasing_dict=biasing_dict, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + # save_results( + # params=params, + # test_set_name=test_set, + # results_dict=results_dict, + # ) + + add_decoding_result_to_manifest( + in_manifest=cuts, + out_manifest=out_name, + results_dict=results_dict, + ) + + logging.info("Done!") + + +# torch.set_num_threads(1) +# torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py new file mode 100644 index 000000000..93e0f9f7e --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py @@ -0,0 +1,130 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scaling import Balancer + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + padding_idx=blank_id, + ) + # the balancers are to avoid any drift in the magnitude of the + # embeddings, which would interact badly with parameter averaging. 
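+ # (min_abs/max_abs passed to the Balancer below give the target range for
+ # the magnitude of the embedding output)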
+ self.balancer = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) + + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim // 4, # group size == 4 + bias=False, + ) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + y = y.to(torch.int64) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + + embedding_out = self.balancer(embedding_out) + + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + embedding_out = self.balancer2(embedding_out) + + return embedding_out diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py b/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py new file mode 100644 index 000000000..257facce4 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py @@ -0,0 +1,43 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn as nn + + +class EncoderInterface(nn.Module): + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (batch_size, input_seq_len, num_features) + containing the input features. + x_lens: + A tensor of shape (batch_size,) containing the number of frames + in `x` before padding. + Returns: + Return a tuple containing two tensors: + - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) + containing unnormalized probabilities, i.e., the output of a + linear layer. + - encoder_out_lens, a tensor of shape (batch_size,) containing + the number of frames in `encoder_out` before padding. 
+ """ + raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py b/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py new file mode 100644 index 000000000..e0bc556a8 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +""" +Export `model.state_dict()` + +- For non-streaming model: + +./zipformer_prompt_asr/export_PromptASR.py \ + --exp-dir ./zipformer_prompt_asr/exp \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --epoch 50 \ + --avg 10 + +- For streaming model: + +./zipformer_prompt_asr/export_PromptASR.py \ + --exp-dir ./zipformer_prompt_asr/exp \ + --causal 1 \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --epoch 50 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from torch import Tensor, nn +from train_bert_encoder import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + assert params.jit is False, "Jit is not 
supported yet" + + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py b/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py new file mode 100644 index 000000000..59f822748 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py @@ -0,0 +1,86 @@ +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from scaling import ScaledLinear + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + context_dim: int = 512, + context_injection: bool = False, + ): + super().__init__() + + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) + self.output_linear = nn.Linear(joiner_dim, vocab_size) + if context_injection: + self.context_proj = ScaledLinear( + context_dim, joiner_dim, initial_scale=0.25 + ) + else: + self.context_proj = None + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + context: torch.Tensor = None, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + context: + An embedding vector representing the previous context information + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). 
+ """ + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape[:-1] == decoder_out.shape[:-1] + + if project_input: + if context: + logit = ( + self.encoder_proj(encoder_out) + + self.decoder_proj(decoder_out) + + self.context_proj(context) + ) + else: + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + else: + if context is not None: + logit = encoder_out + decoder_out + context.unsqueeze(1).unsqueeze(1) + else: + logit = encoder_out + decoder_out + + logit = self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py b/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py new file mode 100644 index 000000000..9a693ca4f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py @@ -0,0 +1,153 @@ +import re + +words = { + 0: "zero", + 1: "one", + 2: "two", + 3: "three", + 4: "four", + 5: "five", + 6: "six", + 7: "seven", + 8: "eight", + 9: "nine", + 10: "ten", + 11: "eleven", + 12: "twelve", + 13: "thirteen", + 14: "fourteen", + 15: "fifteen", + 16: "sixteen", + 17: "seventeen", + 18: "eighteen", + 19: "nineteen", + 20: "twenty", + 30: "thirty", + 40: "forty", + 50: "fifty", + 60: "sixty", + 70: "seventy", + 80: "eighty", + 90: "ninety", +} +ordinal_nums = [ + "zeroth", + "first", + "second", + "third", + "fourth", + "fifth", + "sixth", + "seventh", + "eighth", + "ninth", + "tenth", + "eleventh", + "twelfth", + "thirteenth", + "fourteenth", + "fifteenth", + "sixteenth", + "seventeenth", + "eighteenth", + "nineteenth", + "twentieth", +] + +num_ordinal_dict = {num: ordinal_nums[num] for num in range(21)} + + +def year_to_words(num: int): + assert isinstance(num, int), num + # check if a num is representing a year + if num > 1500 and num < 2000: + return words[num // 100] + " " + num_to_words(num % 100) + elif num == 2000: + return "TWO THOUSAND" + elif num > 2000: + return "TWO THOUSAND AND " + num_to_words(num % 100) + else: + return num_to_words(num) + + +def num_to_words(num: int): + # Return the English words of a integer number + + # If this is a year number + if num > 1500 and num < 2030: + return year_to_words(num) + + if num < 20: + return words[num] + if num < 100: + if num % 10 == 0: + return words[num // 10 * 10] + else: + return words[num // 10 * 10] + " " + words[num % 10] + if num < 1000: + return words[num // 100] + " hundred and " + num_to_words(num % 100) + if num < 1000000: + return num_to_words(num // 1000) + " thousand " + num_to_words(num % 1000) + return num + + +def num_to_ordinal_word(num: int): + + return num_ordinal_dict.get(num, num_to_words(num)).upper() + + +def replace_full_width_symbol(s: str) -> str: + # replace full-width symbol with theri half width counterpart + s = s.replace("“", '"') + s = s.replace("”", '"') + s = s.replace("‘", "'") + s = s.replace("’", "'") + + return s + + +def decoding_normalization(text: str) -> str: + text = replace_full_width_symbol(text) + + # Only keep all alpha-numeric characters, hypen and apostrophe + text = text.replace("-", " ") + text = re.sub(r"[^a-zA-Z0-9\s']+", "", text) + return text + + +def word_normalization(word: str) -> str: + # 1 .Use full word for some abbreviation + # 2. Convert digits to english words + # 3. 
Convert ordinal number to english words + if word == "MRS": + return "MISSUS" + if word == "MR": + return "MISTER" + if word == "ST": + return "SAINT" + if word == "ECT": + return "ET CETERA" + if word.isnumeric(): + word = num_to_words(int(word)) + return str(word).upper() + # e.g 9TH, 6TH + if word[-2:] == "TH" and word[0].isnumeric(): + return num_to_ordinal_word(int(word[:-2])).upper() + if word[0] == "'": + return word[1:] + + return word + + +def simple_normalization(text: str) -> str: + text = replace_full_width_symbol(text) + text = text.replace("--", " ") + + return text + + +if __name__ == "__main__": + + s = str(1830) + out = word_normalization(s) + print(s, out) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py new file mode 100644 index 000000000..77b4057c4 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py @@ -0,0 +1,262 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +import warnings +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear, penalize_abs_values_gt +from torch import Tensor + +from icefall.utils import add_sos, make_pad_mask + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. 
+ """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, + vocab_size, + initial_scale=0.25, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, + vocab_size, + initial_scale=0.25, + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + text: + A 2-D tensor of integer dtype containing prompt text, of shape (N, T). + It is exptected to contain the style prompt (first) and then the content + prompt. + text_lens: + A 1-D tensor of shape (N,). It contains the number of elements (bytes) + in `text` before padding, which will include the lengths of the + style plus the content prompt. + style_lens: + A 1-D tensor of shape (N,), containing the number of elements (bytes) + within each row of `text` that correspond to the style prompt (these + are expected to come first). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + x, x_lens = self.encoder_embed(x, x_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, x_lens = self.encoder( + x, + x_lens, + src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. 
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) + + def encode_audio( + self, + feature: Tensor, + feature_lens: Tensor, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """Encode the input audio features + + Args: + feature (Tensor): Input audio (N,T,C) + feature_lens (Tensor): Length of input audio (N,) + Returns: + Tuple[Tensor, Tensor]: Encoded acoustic features and length + """ + x, x_lens = self.encoder_embed(feature, feature_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder( + x=x, + x_lens=x_lens, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py new file mode 100644 index 000000000..21c7b4fac --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py @@ -0,0 +1,392 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
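+ # This file defines PromptedTransducer: a pruned-RNNT transducer whose
+ # acoustic encoder is additionally conditioned on the output of a (by default
+ # frozen) BERT-style text encoder that embeds the style and content prompts.
+ #
+ # Minimal usage sketch -- illustrative only; the sub-modules are assumed to be
+ # constructed elsewhere (e.g. by get_transducer_model() in the training script):
+ #
+ #   model = PromptedTransducer(
+ #       encoder_embed, encoder, text_encoder, decoder, joiner,
+ #       encoder_dim, decoder_dim, joiner_dim, vocab_size,
+ #       text_encoder_type="BERT",
+ #   )
+ #   memory, memory_key_padding_mask = model.encode_text(encoded_inputs, style_lens)
+ #   encoder_out, encoder_out_lens = model.encode_audio(
+ #       feature, feature_lens, memory, memory_key_padding_mask
+ #   )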
+ + +import random +import warnings +from typing import Dict, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear, penalize_abs_values_gt +from torch import Tensor + +from icefall.utils import add_sos, make_pad_mask + + +class PromptedTransducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + text_encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + use_BERT: bool = True, + text_encoder_type: str = "BERT", + text_encoder_adapter: bool = False, + freeze_text_encoder: bool = True, + context_fuser: nn.Module = None, + ): + """ + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + text_encoder: + This is a encoder that processes text information (e.g content prompt + and style prompt). The input is `x` of (N,T) and `x_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + text_encoder_type: + The type of the text_encoder. Supported are (BERT, DistilBERT) + context_fuser + A optional module that fuses the embeddings of text encoder. The fused embedding + will be added to the joiner. 
+ """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.text_encoder = text_encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, + vocab_size, + initial_scale=0.25, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, + vocab_size, + initial_scale=0.25, + ) + + self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT + self.context_fuser = context_fuser + + assert text_encoder_type in ( + "BERT", + "DistilBERT", + "BERT-UNCASED", + ), f"Unseen text_encoder type {text_encoder_type}" + self.text_encoder_dim = ( + self.text_encoder.config.hidden_size + if text_encoder_type in ("BERT", "BERT-UNCASED") + else self.text_encoder.config.dim + ) + self.freeze_text_encoder = freeze_text_encoder + + if text_encoder_adapter: + self.text_encoder_adapter = nn.Sequential( + nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False), + nn.Tanh(), + ) + else: + self.text_encoder_adapter = None + + self.style_prompt_embedding = nn.Parameter( + torch.full((self.text_encoder_dim,), 0.5) + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + encoded_inputs: Dict, + style_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + use_pre_text: bool = True, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + text: + A 2-D tensor of integer dtype containing prompt text, of shape (N, T). + It is exptected to contain the style prompt (first) and then the content + prompt. + text_lens: + A 1-D tensor of shape (N,). It contains the number of elements (bytes) + in `text` before padding, which will include the lengths of the + style plus the content prompt. + style_lens: + A 1-D tensor of shape (N,), containing the number of elements (bytes) + within each row of `text` that correspond to the style prompt (these + are expected to come first). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. 
+ + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + if self.freeze_text_encoder: + self.text_encoder.eval() + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + x, x_lens = self.encoder_embed(x, x_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # freeze the BERT text encoder + + if use_pre_text: + memory, memory_key_padding_mask = self.encode_text( + encoded_inputs, style_lens=style_lens + ) + else: + memory = None + memory_key_padding_mask = None + + encoder_out, x_lens = self.encoder( + x, + x_lens, + src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + if self.context_fuser is not None and memory is not None: + memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C) + context = self.context_fuser(memory, padding_mask=memory_key_padding_mask) + context = self.joiner.context_proj(context) + else: + context = None + + logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) + + def _add_style_indicator(self, memory: Tensor, style_lens: Tensor): + """ + Adds to `memory` an indicator that is 1.0 for positions that correspond to + the `style prompt` and 0 elsewhere. 
The scale can be fixed because the + scale of the embedding vector can adjust to compensate. + + Args: + memory: (memory_len, batch_size, embed_dim) + style_lens: (batch_size,), a vector of lengths of the style prompt. + """ + + (memory_len, batch_size, embed_dim) = memory.shape + + indicator = ( + torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens + ) + indicator = indicator.to(memory.dtype) + + extra_term = torch.zeros_like(memory) + extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand( + memory_len, batch_size, self.text_encoder_dim + ) + + return memory + extra_term + + def encode_text( + self, + encoded_inputs: Dict, + style_lens: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Get the embeddings of text + + Args: + encoded_inputs: The encoded inputs generated by a tokenizer (Dict) + + Returns: + Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the + text_encoder and the attention mask + """ + text_lens = encoded_inputs.pop("length") # need to use pop to remove this item + + # Freeze the pre-trained text encoder + with torch.no_grad(): + memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C) + memory = memory.permute(1, 0, 2) + + # Text encoder adapter + if self.text_encoder_adapter is not None: + memory = self.text_encoder_adapter(memory) + + memory = self._add_style_indicator(memory, style_lens) + + memory_key_padding_mask = make_pad_mask(text_lens) + + return memory, memory_key_padding_mask + + def encode_audio( + self, + feature: Tensor, + feature_lens: Tensor, + memory: Optional[Tensor], + memory_key_padding_mask: Optional[Tensor], + ) -> Tuple[Tensor, Tensor]: + """Encode the input audio features + + Args: + feature (Tensor): Input audio (N,T,C) + feature_lens (Tensor): Length of input audio (N,) + memory (Tensor): Embeddings from the text encoder + memory_key_padding_mask (Tensor): _description_ + + Returns: + Tuple[Tensor, Tensor]: _description_ + """ + x, x_lens = self.encoder_embed(feature, feature_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder( + x=x, + x_lens=x_lens, + src_key_padding_mask=src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +Transducer = PromptedTransducer # for decoding diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py new file mode 100644 index 000000000..a767761eb --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py @@ -0,0 +1,1168 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
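+ # Usage sketch, taken from the ScaledAdam docstring further below (illustrative
+ # only; `model` and `params.base_lr` are assumed to be defined by the caller):
+ #
+ #   optimizer = ScaledAdam(
+ #       model.named_parameters(), lr=params.base_lr, clipping_scale=3.0
+ #   )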
+ +import contextlib +import logging +import random +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! 
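+ # Once the `with` block exits, copy the (possibly updated) stacked values
+ # back into the individual parameter tensors so that the real model weights
+ # reflect the optimization step performed on the batched copies.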
+ + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + Unlike common optimizers, which accept model.parameters() or groups of parameters(), + this optimizer could accept model.named_parameters() or groups of named_parameters(). + See comments of function _get_names_of_parameters for its 4 possible cases. + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. 
+ clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): + + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + # If params only contains parameters or group of parameters, + # i.e when parameter names are not given, + # this flag will be set to False in funciton _get_names_of_parameters. + self.show_dominant_parameters = True + param_groups, parameters_names = self._get_names_of_parameters(params) + super(ScaledAdam, self).__init__(param_groups, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + + def _get_names_of_parameters( + self, params_or_named_params + ) -> Tuple[List[Dict], List[List[str]]]: + """ + Args: + params_or_named_params: according to the way ScaledAdam is initialized in train.py, + this argument could be one of following 4 cases, + case 1, a generator of parameter, e.g.: + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 2, a list of parameter groups with different config, e.g.: + model_param_groups = [ + {'params': model.encoder.parameters(), 'lr': 0.05}, + {'params': model.decoder.parameters(), 'lr': 0.01}, + {'params': model.joiner.parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + + case 3, a generator of named_parameter, e.g.: + optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 4, a list of named_parameter groups with different config, e.g.: + model_named_param_groups = [ + {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, + {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, + {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + + For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. + For case 3 and case 4, firstly, names and params are extracted from input named_params, + then, these extracted params are used to initialize the underlying torch.optimizer, + and these extracted names are mainly used by function + `_show_gradient_dominating_parameter` + + Returns: + Returns a tuple containing 2 elements: + - `param_groups` with type List[Dict], each Dict element is a parameter group. + An example of `param_groups` could be: + [ + {'params': `one iterable of Parameter`, 'lr': 0.05}, + {'params': `another iterable of Parameter`, 'lr': 0.08}, + {'params': `a third iterable of Parameter`, 'lr': 0.1}, + ] + - `param_gruops_names` with type List[List[str]], + each `List[str]` is for a group['params'] in param_groups, + and each `str` is the name of a parameter. + A dummy name "foo" is related to each parameter, + if input are params without names, i.e. case 1 or case 2. + """ + # variable naming convention in this function: + # p is short for param. + # np is short for named_param. + # p_or_np is short for param_or_named_param. + # cur is short for current. 
+ # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. + # groups is a List[group] + + iterable_or_groups = list(params_or_named_params) + if len(iterable_or_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + + # The first value of returned tuple. A list of dicts containing at + # least 'params' as a key. + param_groups = [] + + # The second value of returned tuple, + # a List[List[str]], each sub-List is for a group. + param_groups_names = [] + + if not isinstance(iterable_or_groups[0], dict): + # case 1 or case 3, + # the input is an iterable of parameter or named parameter. + param_iterable_cur_group = [] + param_names_cur_group = [] + for p_or_np in iterable_or_groups: + if isinstance(p_or_np, tuple): + # case 3 + name, param = p_or_np + else: + # case 1 + assert isinstance(p_or_np, torch.Tensor) + param = p_or_np + # Assign a dummy name as a placeholder + name = "foo" + self.show_dominant_parameters = False + param_iterable_cur_group.append(param) + param_names_cur_group.append(name) + param_groups.append({"params": param_iterable_cur_group}) + param_groups_names.append(param_names_cur_group) + else: + # case 2 or case 4 + # the input is groups of parameter or named parameter. + for cur_group in iterable_or_groups: + assert "named_params" in cur_group + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] + del cur_group["named_params"] + cur_group["params"] = p_list + param_groups.append(cur_group) + param_groups_names.append(name_list) + + return param_groups, param_groups_names + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. 
+ p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + numel = p.numel() + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state, param_names) in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. 
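+                # What follows is a periodic summary: we sort the gradient norms
+                # collected over the most recent `clipping_update_period` steps,
+                # log a few quantiles of that distribution, and set the clipping
+                # threshold to `clipping_scale` times the median norm.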
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) + return 1.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + return ans + + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter which dominates tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad**2 + # Dummy values used by following `zip` statement. 
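+                # (For a batch of scalars we report the rms as 1.0 and use
+                # grad**2 directly as the per-element sumsq.)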
+ batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter dominating tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq={(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. 
+ p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. 
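+        # In summary (a restatement of the code below, not extra logic): with
+        #   denom = sqrt(exp_avg_sq / bias_correction2) + eps,
+        # the momentum buffer becomes
+        #   delta <- beta1 * delta - lr * (1 - beta1) * grad / denom
+        # (the beta1 factor on delta was already applied in _step_one_batch),
+        # and p is clamped to [-scalar_max, scalar_max] before delta is added.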
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[Union[int, float]] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. 
suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.25 * ( + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + if random.random() < 0.0005: + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + + from scaling import ScaledLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. 
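+    # Each dimension gets its own random (log-normally distributed) magnitude,
+    # so the optimizer has to learn per-dimension scales rather than a single
+    # global one.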
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 2 ** 22 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py new file mode 100644 index 000000000..48fd2612a --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script loads a checkpoint (`pretrained.pt`) and uses it to decode waves. +You can generate the checkpoint with the following command: + +./zipformer/export_PromptASR.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --epoch 50 \ + --avg 10 + +Utterance level context biasing: + +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --method modified_beam_search \ + --use-pre-text True \ + --content-prompt "bessy random words hello k2 ASR" \ + --use-style-prompt True \ + librispeech.flac + + +Word level context biasing: + +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --method modified_beam_search \ + --use-pre-text True \ + --content-prompt "The topic is about horses." \ + --use-style-prompt True \ + test.wav + + +""" + +import argparse +import logging +import math +import warnings +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import greedy_search_batch, modified_beam_search +from text_normalization import _apply_style_transform, train_text_normalization +from torch.nn.utils.rnn import pad_sequence +from train_bert_encoder import ( + _encode_texts_as_bytes_with_tokenizer, + add_model_arguments, + get_params, + get_tokenizer, + get_transducer_model, +) + +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500_fallback_coverage_0.99/bpe.model", + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. 
Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + parser.add_argument( + "--use-pre-text", + type=str2bool, + default=True, + help="Use content prompt during decoding", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Use style prompt during decoding", + ) + + parser.add_argument( + "--pre-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of content prompt, i.e pre_text", + ) + + parser.add_argument( + "--style-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of style prompt, i.e style_text", + ) + + parser.add_argument( + "--content-prompt", type=str, default="", help="The content prompt for decoding" + ) + + parser.add_argument( + "--style-prompt", + type=str, + default="Mixed-cased English text with punctuations, feel free to change it.", + help="The style prompt for decoding", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
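+
+    # The remainder of main() builds the model and the text tokenizer, loads the
+    # checkpoint, computes fbank features for the input wave, optionally encodes
+    # the content/style prompts, runs the audio encoder, and decodes with the
+    # requested search method.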
+ + logging.info("Creating model") + model = get_transducer_model(params) + tokenizer = get_tokenizer(params) # for text encoder + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + assert ( + len(params.sound_files) == 1 + ), "Only support decoding one audio at this moment" + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # encode prompts + if params.use_pre_text: + pre_text = [train_text_normalization(params.content_prompt)] + pre_text = _apply_style_transform(pre_text, params.pre_text_transform) + else: + pre_text = [""] + + if params.use_style_prompt: + style_text = [params.style_prompt] + style_text = _apply_style_transform(style_text, params.style_text_transform) + else: + style_text = [""] + + if params.use_pre_text or params.use_style_prompt: + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_text, + style_texts=style_text, + tokenizer=tokenizer, + device=device, + no_limit=True, + ) + + memory, memory_key_padding_mask = model.encode_text( + encoded_inputs=encoded_inputs, + style_lens=style_lens, + ) # (T,B,C) + else: + memory = None + memory_key_padding_mask = None + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=features, + feature_lens=feature_lengths, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + if params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + hyps.append(sp.decode(hyp_tokens)[0]) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + hyps.append(sp.decode(hyp_tokens)[0]) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py new file mode 100644 index 000000000..0e6764ba0 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py @@ -0,0 +1,1872 @@ +# Copyright 2022 Xiaomi Corp. 
(authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import logging +import math +import random +from functools import reduce +from itertools import repeat +from typing import Optional, Tuple, Union + +import k2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn import Embedding as ScaledEmbedding + + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + + def __init__(self, *args): + assert len(args) >= 1 + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [(float(x), float(y)) for x, y in args] + for (x, y) in self.pairs: + assert isinstance(x, float) or isinstance(x, int) + assert isinstance(y, float) or isinstance(y, int) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], self.pairs + + def __str__(self): + # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if x >= cur_x and x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear( + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise lienar + functions to self and p, but with the same x values. 
+ + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p crosss. + """ + assert isinstance(p, PiecewiseLinear) + + # get sorted x-values without repetition. + x_vals = sorted(set([x for x, y in self.pairs] + [x for x, y in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): + # if the two lines in this subsegment potentially cross each other.. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. + pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specifiy the (x,y) pairs + in sorted order on x; x corresponds to the batch index. For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or in training or mode or in + torch.jit scripting mode. + """ + + def __init__(self, *args, default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. + self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) + + def __float__(self): + batch_count = self.batch_count + if batch_count is None or not self.training or torch.jit.is_scripting(): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, default=self.default) + else: + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), default=self.default) + else: + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) + + +FloatLike = Union[float, ScheduledFloat] + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. 
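+
+    Values whose magnitude is below `min_abs` are randomly replaced by either
+    sign(x) * min_abs (with probability |x| / min_abs) or 0.0, so that the
+    expected value of the output matches the input; larger values are simply
+    cast to float16.  For example (illustrative, not a doctest):
+        random_cast_to_half(torch.tensor([1.0e-07, 0.5]))
+    returns a float16 tensor whose second element is 0.5 and whose first element
+    is 0.0 most of the time and min_abs occasionally.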
+ """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class CutoffEstimator: + """ + Estimates cutoffs of an arbitrary numerical quantity such that a specified + proportion of items will be above the cutoff on average. + + p is the proportion of items that should be above the cutoff. + """ + + def __init__(self, p: float): + self.p = p + # total count of items + self.count = 0 + # total count of items that were above the cutoff + self.count_above = 0 + # initial cutoff value + self.cutoff = 0 + + def __call__(self, x: float) -> bool: + """ + Returns true if x is above the cutoff. + """ + ans = x > self.cutoff + self.count += 1 + if ans: + self.count_above += 1 + cur_p = self.count_above / self.count + delta_p = cur_p - self.p + if (delta_p > 0) == ans: + q = abs(delta_p) + self.cutoff = x * q + self.cutoff * (1 - q) + return ans + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x**2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. 
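+            # The extra gradient added below nudges x away from its dominant
+            # eigen-direction; it is rescaled so that its norm is `grad_scale`
+            # times the norm of the incoming x_grad.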
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BiasNormFunction(torch.autograd.Function): + # This computes: + # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() + # return (x - bias) * scales + # (after unsqueezing the bias), but it does it in a memory-efficient way so that + # it can just store the returned value (chances are, this will also be needed for + # some other reason, related to the next operation, so we can save memory). + @staticmethod + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: + assert bias.ndim == 1 + if channel_dim < 0: + channel_dim = channel_dim + x.ndim + ctx.store_output_for_backprop = store_output_for_backprop + ctx.channel_dim = channel_dim + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + ans_or_x, scales, bias, log_scale = ctx.saved_tensors + if ctx.store_output_for_backprop: + x = ans_or_x / scales + else: + x = ans_or_x + x = x.detach() + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + return x.grad, bias.grad.flatten(), log_scale.grad, None, None + + +class BiasNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + Instead, we give the BiasNorm a trainable bias that it can use when + computing the scale for normalization. We also give it a (scalar) + trainable scale on the output. + + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + log_scale: the initial log-scale that we multiply the output by; this + is learnable. + log_scale_min: FloatLike, minimum allowed value of log_scale + log_scale_max: FloatLike, maximum allowed value of log_scale + store_output_for_backprop: only possibly affects memory use; recommend + to set to True if you think the output of this module is more likely + than the input of this module to be required to be stored for the + backprop. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. 
+ log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, + ) -> None: + super(BiasNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.log_scale = nn.Parameter(torch.tensor(log_scale)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + + self.log_scale_min = log_scale_min + self.log_scale_max = log_scale_max + + self.store_output_for_backprop = store_output_for_backprop + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + + if torch.jit.is_scripting(): + channel_dim = self.channel_dim + if channel_dim < 0: + channel_dim += x.ndim + bias = self.bias + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() + return x * scales + + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) + + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: + """ + Behaves like a constructor of a modified version of nn.Conv2d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False, but: + NO PADDING-RELATED ARGS. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. 
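+
+    For example (an illustrative call, not taken from the recipes):
+        ScaledConv2d(in_channels=1, out_channels=8, kernel_size=3, initial_scale=0.25)
+    behaves like the corresponding nn.Conv2d, except that the weight is multiplied
+    by `initial_scale` and the bias is re-initialized to a small uniform range
+    proportional to `initial_scale`.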
+ """ + ans = nn.Conv2d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +class ChunkCausalDepthwiseConv1d(torch.nn.Module): + """ + Behaves like a depthwise 1d convolution, except that it is causal in + a chunkwise way, as if we had a block-triangular attention mask. + The chunk size is provided at test time (it should probably be + kept in sync with the attention mask). + + This has a little more than twice the parameters of a conventional + depthwise conv1d module: we implement it by having one + depthwise convolution, of half the width, that is causal (via + right-padding); and one depthwise convolution that is applied only + within chunks, that we multiply by a scaling factor which depends + on the position within the chunk. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): + super().__init__() + assert kernel_size % 2 == 1 + + half_kernel_size = (kernel_size + 1) // 2 + # will pad manually, on one side. + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) + + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) + + # first row is correction factors added to the scale near the left edge of the chunk, + # second row is correction factors added to the scale near the right edge of the chunk, + # both of these are added to a default scale of 1.0. + self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) + self.kernel_size = kernel_size + + with torch.no_grad(): + self.causal_conv.weight[:] *= initial_scale + self.chunkwise_conv.weight[:] *= initial_scale + if bias: + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: + """ + Forward function. Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + """ + (batch_size, num_channels, seq_len) = x.shape + + half_kernel_size = self.kernel_size + 1 // 2 + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. 
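+        # (For the odd kernel_size asserted in __init__, kernel_size // 2 equals
+        # half_kernel_size - 1, i.e. (kernel_size + 1) // 2 - 1.)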
+ left_pad = self.kernel_size // 2 + + if chunk_size < 0 or chunk_size > seq_len: + chunk_size = seq_len + right_pad = -seq_len % chunk_size + + x = torch.nn.functional.pad(x, (left_pad, right_pad)) + + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + num_chunks = x_chunk.shape[2] // chunk_size + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size) + + x_chunk = x_chunk * chunk_scale + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] + + return x_chunk + x_causal + + def _get_chunk_scale(self, chunk_size: int): + """Returns tensor of shape (num_channels, chunk_size) that will be used to + scale the output of self.chunkwise_conv.""" + left_edge = self.chunkwise_conv_scale[0] + right_edge = self.chunkwise_conv_scale[1] + if chunk_size < self.kernel_size: + left_edge = left_edge[:, :chunk_size] + right_edge = right_edge[:, -chunk_size:] + else: + t = chunk_size - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) + return 1.0 + (left_edge + right_edge) + + +class BalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + ctx.save_for_backward(x) + ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors + (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.detach() + x.requires_grad = True + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. 
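+                    # r_loss below is zero while min_rms <= rms <= max_rms; outside that range it
+                    # grows as |log(rms / nearest_limit)|, i.e. the log-domain distance to the
+                    # violated bound.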
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = m_loss + r_loss + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + x_grad_float = x_grad.to(torch.float32) + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) + x_grad = x_grad_mod.to(x_grad.dtype) + except Exception as e: + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) + + return x_grad, None, None, None, None, None, None + + +class Balancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, + ): + super().__init__() + + if prob is None: + prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) + self.prob = prob + # 5% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.05) + + # actually self.num_channels is no longer needed except for an assertion. 
+ self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.min_abs = min_abs + self.max_abs = max_abs + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): + return _no_op(x) + + prob = float(self.prob) + if random.random() < prob: + # The following inner-functions convert from the way we historically specified + # these limitations, as limits on the absolute value and the proportion of positive + # values, to limits on the RMS value and the (mean / stddev). + def _abs_to_rms(x): + # for normally distributed data, if the expected absolute value is x, the + # expected rms value will be sqrt(pi/2) * x. + return 1.25331413732 * x + + def _proportion_positive_to_mean(x): + def _atanh(x): + eps = 1.0e-10 + # eps is to prevent crashes if x is exactly 0 or 1. + # we'll just end up returning a fairly large value. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 + + def _approx_inverse_erf(x): + # 1 / (sqrt(pi) * ln(2)), + # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions + # this approximation is extremely crude and gets progressively worse for + # x very close to -1 or +1, but we mostly care about the "middle" region + # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, + # and math.erf(0.0407316414078772) = 0.045935330944660666, + # which is pretty close to 0.05. + return 0.8139535143 * _atanh(x) + + # first convert x from the range 0..1 to the range -1..1 which the error + # function returns + x = -1 + (2 * x) + return _approx_inverse_erf(x) + + min_mean = _proportion_positive_to_mean(float(self.min_positive)) + max_mean = _proportion_positive_to_mean(float(self.max_positive)) + min_rms = _abs_to_rms(float(self.min_abs)) + max_rms = _abs_to_rms(float(self.max_abs)) + grad_scale = float(self.grad_scale) + + assert x.shape[self.channel_dim] == self.num_channels + + return BalancerFunction.apply( + x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). 
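+    # Concretely: (x_sign * over_limit) is +/-1 only where |x| > limit and 0 elsewhere, so the
+    # product with x equals |x| on the violating elements; its gradient w.r.t. x is sign(x)
+    # there, which is exactly the gradient of penalty * (x.abs() - limit).relu().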
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: + ctx.save_for_backward(x) + ctx.module = module + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + w = ctx.module + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, w.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. 
limit={float(w.whitening_limit)}" + ) + + if metric < float(w.whitening_limit): + w.prob = w.min_prob + return x_grad, None + else: + w.prob = w.max_prob + metric.backward() + penalty_grad = x_detached.grad + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None + except Exception as e: + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." + ) + return x_grad, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert float(whitening_limit) >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + self.grad_scale = grad_scale + + if isinstance(prob, float): + prob = (prob, prob) + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob <= self.max_prob <= 1 + self.prob = self.max_prob + self.name = None # will be set in training loop + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + grad_scale = float(self.grad_scale) + if not x.requires_grad or random.random() > self.prob or grad_scale == 0: + return _no_op(x) + else: + return WhiteningPenaltyFunction.apply(x, self) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) + + +def with_loss(x, y, name): + # returns x but adds y.sum() to the loss function. 
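+    # In backward, WithLoss passes the incoming gradient through to x unchanged and feeds a
+    # gradient of all ones to y, which is what summing y into the loss would have produced.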
+ return WithLoss.apply(x, y, name) + + +class ScaleGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, alpha: float) -> Tensor: + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad: Tensor): + return grad * ctx.alpha, None + + +def scale_grad(x: Tensor, alpha: float): + return ScaleGradFunction.apply(x, alpha) + + +class ScaleGrad(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: Tensor) -> Tensor: + return scale_grad(x, self.alpha) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x,) = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. + if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. 
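+            # Storage trick (illustrative summary): the derivative, known from the notes above to
+            # lie roughly in [-0.0436, 1.199], is affinely mapped to [0, 255], dithered with
+            # uniform noise and stored as uint8, so backward() needs 1 byte per element instead
+            # of caching the float input.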
+ floor = -0.044 + ceil = 1.2 + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. +class Dropout2(nn.Module): + def __init__(self, p: FloatLike): + super().__init__() + self.p = p + + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) + + +class MulForDropout3(torch.autograd.Function): + # returns (x * y * alpha) where alpha is a float and y doesn't require + # grad and is zero-or-one. + @staticmethod + @custom_fwd + def forward(ctx, x, y, alpha): + assert not y.requires_grad + ans = x * y * alpha + ctx.save_for_backward(ans) + ctx.alpha = alpha + return ans + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + (ans,) = ctx.saved_tensors + x_grad = ctx.alpha * ans_grad * (ans != 0) + return x_grad, None, None + + +# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, +# and it lets you choose one dimension to share the dropout mask over +class Dropout3(nn.Module): + def __init__(self, p: FloatLike, shared_dim: int): + super().__init__() + self.p = p + self.shared_dim = shared_dim + + def forward(self, x: Tensor) -> Tensor: + p = float(self.p) + if not self.training or p == 0: + return _no_op(x) + scale = 1.0 / (1 - p) + rand_shape = list(x.shape) + rand_shape[self.shared_dim] = 1 + mask = torch.rand(*rand_shape, device=x.device) > p + ans = MulForDropout3.apply(x, mask, scale) + return ans + + +class SwooshLFunction(torch.autograd.Function): + """ + swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + coeff = -0.08 + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 + + if not requires_grad: + return y + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = coeff + ceil = 1.0 + coeff + 0.005 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. 
+ assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + + coeff = -0.08 + floor = coeff + ceil = 1.0 + coeff + 0.005 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshL(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation.""" + if torch.jit.is_scripting(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + if not x.requires_grad: + return k2.swoosh_l_forward(x) + else: + return k2.swoosh_l(x) + # return SwooshLFunction.apply(x) + + +class SwooshRFunction(torch.autograd.Function): + """ + swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + + derivatives are between -0.08 and 0.92. + + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + + if not requires_grad: + return y + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = -0.08 + ceil = 0.925 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.08 + ceil = 0.925 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshR(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation.""" + if torch.jit.is_scripting(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + if not x.requires_grad: + return k2.swoosh_r_forward(x) + else: + return k2.swoosh_r(x) + # return SwooshRFunction.apply(x) + + +# simple version of SwooshL that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshLForward(x: Tensor): + x_offset = x - 4.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.035 + + +# simple version of SwooshR that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. 
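+# (For reference: the constant 0.313261687 equals log(1 + exp(-1)), so SwooshR(0) is
+# essentially zero by construction.)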
+def SwooshRForward(x: Tensor): + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.313261687 + + +class ActivationDropoutAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): + if dropout_p != 0.0: + dropout_shape = list(x.shape) + if dropout_shared_dim is not None: + dropout_shape[dropout_shared_dim] = 1 + # else it won't be very memory efficient. + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) + else: + dropout_mask = None + + ctx.save_for_backward(x, weight, bias, dropout_mask) + + ctx.activation = activation + + forward_activation_dict = { + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, + } + # it will raise a KeyError if this fails. This will be an error. We let it + # propagate to the user. + activation_func = forward_activation_dict[activation] + x = activation_func(x) + if dropout_mask is not None: + x = x * dropout_mask + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias, dropout_mask) = saved + + forward_and_deriv_activation_dict = { + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, + } + # the following lines a KeyError if the activation is unrecognized. + # This will be an error. We let it propagate to the user. + func = forward_and_deriv_activation_dict[ctx.activation] + + y, func_deriv = func(x) + if dropout_mask is not None: + y = y * dropout_mask + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + if dropout_mask is not None: + # order versus func_deriv does not matter + x_deriv = x_deriv * dropout_mask + + return x_deriv, weight_deriv, bias_deriv, None, None, None + + +class ActivationDropoutAndLinear(torch.nn.Module): + """ + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). 
If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + layer = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) + + self.weight = layer.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter("bias", layer.bias) + + self.activation = activation + self.dropout_p = dropout_p + self.dropout_shared_dim = dropout_shared_dim + + def forward(self, x: Tensor): + if torch.jit.is_scripting(): + if self.activation == "SwooshL": + x = SwooshLForward(x) + elif self.activation == "SwooshR": + x = SwooshRForward(x) + else: + assert False, self.activation + return torch.nn.functional.linear(x, self.weight, self.bias) + + return ActivationDropoutAndLinearFunction.apply( + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) + + +class ClipGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, limit: float): + ctx.limit = limit + return x + + @staticmethod + def backward(ctx, x_grad, *args): + return x_grad.clamp(-ctx.limit, ctx.limit), None + + +def clip_grad(x: Tensor, limit: float): + return ClipGradFunction.apply(x, limit) + + +class AbsValuePenalizer(nn.Module): + """ + This module adds a penalty to the loss function when ever the absolute value of + any element of the input tensor exceeds a certain limit. + """ + + def __init__(self, limit: float, prob: float = 0.1, penalty: float = 1.0e-04): + super().__init__() + self.limit = limit + self.penalty = penalty + + self.prob = prob + self.name = None # will be set in training loop + + # 20% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.2) + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or not self.training + or random.random() > self.prob + ): + # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + return _no_op(x) # the _no_op op is to make our diagnostics code work. 
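+        # Only the calls that reach this point (roughly a fraction self.prob of training
+        # batches that require grad) actually attach the penalty via penalize_abs_values_gt().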
+ + x = penalize_abs_values_gt( + x, limit=self.limit, penalty=self.penalty, name=self.name + ) + return x + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + if num_channels <= x.shape[-1]: + return x[..., :num_channels] + else: + shape = list(x.shape) + shape[-1] = num_channels - shape[-1] + zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=-1) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = Balancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + min_abs=0.0, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_sign: x = ", x) + print("_test_balancer_sign: y grad = ", y_grad) + print("_test_balancer_sign: x grad = ", x.grad) + + +def _test_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = Balancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + min_abs=0.2, + max_abs=0.7, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_magnitude: x = ", x) + print("_test_balancer_magnitude: y grad = ", y_grad) + print("_test_balancer_magnitude: x grad = ", x.grad) + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshL() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshR() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. 
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +def _test_piecewise_linear(): + p = PiecewiseLinear((0, 10.0)) + for x in [-100, 0, 100]: + assert p(x) == 10.0 + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: + print("x, y = ", x, y) + assert p(x) == y, (x, p(x), y) + + q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] + pq = p.max(q) + for x in x_vals: + y1 = max(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p.min(q) + for x in x_vals: + y1 = min(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p + q + for x in x_vals: + y1 = p(x) + q(x) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + + +def _test_activation_dropout_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + # actually we don't test for dropout_p != 0.0 because forward functions will give + # different answers. This is because we are using the k2 implementation of + # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() + # internally, messing up the random state. + for dropout_p in [0.0]: + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) + with torch.no_grad(): + m2.weight[:] = m1[2].weight + if bias: + m2.bias[:] = m1[2].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + # TEMP. + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) + if bias: + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + + # the SwooshL() implementation has a noisy gradient due to 1-byte + # storage of it. 
+ assert isclose(x1.grad, x2.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_piecewise_linear() + _test_softmax() + _test_whiten() + _test_balancer_sign() + _test_balancer_magnitude() + _test_swooshl_deriv() + _test_swooshr_deriv() + _test_activation_dropout_and_linear() + _test_double_swish_deriv() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py new file mode 100644 index 000000000..7acbc1808 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Tuple + +import torch +from scaling import ( + Balancer, + BiasNorm, + Dropout3, + FloatLike, + Optional, + ScaledConv2d, + ScaleGrad, + ScheduledFloat, + SwooshL, + SwooshR, + Whiten, +) +from torch import Tensor, nn + + +class ConvNeXt(nn.Module): + """ + Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf + """ + + def __init__( + self, + channels: int, + hidden_ratio: int = 3, + kernel_size: Tuple[int, int] = (7, 7), + layerdrop_rate: FloatLike = None, + ): + super().__init__() + padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) + hidden_channels = channels * hidden_ratio + if layerdrop_rate is None: + layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) + self.layerdrop_rate = layerdrop_rate + + self.depthwise_conv = nn.Conv2d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=padding, + ) + + self.pointwise_conv1 = nn.Conv2d( + in_channels=channels, out_channels=hidden_channels, kernel_size=1 + ) + + self.hidden_balancer = Balancer( + hidden_channels, + channel_dim=1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + self.activation = SwooshL() + self.pointwise_conv2 = ScaledConv2d( + in_channels=hidden_channels, + out_channels=channels, + kernel_size=1, + initial_scale=0.01, + ) + + self.out_balancer = Balancer( + channels, + channel_dim=1, + min_positive=0.4, + max_positive=0.6, + min_abs=1.0, + max_abs=6.0, + ) + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or not self.training: + return self.forward_internal(x) + layerdrop_rate = float(self.layerdrop_rate) + + if layerdrop_rate != 0.0: + batch_size = x.shape[0] + mask = ( + torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) + > layerdrop_rate + ) + else: + mask = None + # turns out this caching idea does not work with --world-size > 1 + # return caching_eval(self.forward_internal, x, mask) + return self.forward_internal(x, mask) + + def 
forward_internal( + self, x: Tensor, layer_skip_mask: Optional[Tensor] = None + ) -> Tensor: + """ + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + + The returned value has the same shape as x. + """ + bypass = x + x = self.depthwise_conv(x) + x = self.pointwise_conv1(x) + x = self.hidden_balancer(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + if layer_skip_mask is not None: + x = x * layer_skip_mask + + x = bypass + x + x = self.out_balancer(x) + x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last + x = self.out_whiten(x) + x = x.transpose(1, 3) # (N, C, H, W) + + return x + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: FloatLike = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-3)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + bottleneck: + bottleneck dimension for 1d squeeze-excite + """ + assert in_channels >= 7 + super().__init__() + + # The ScaleGrad module is there to prevent the gradients + # w.r.t. the weight or bias of the first Conv2d module in self.conv from + # exceeding the range of fp16 when using automatic mixed precision (amp) + # training. (The second one is necessary to stop its bias from getting + # a too-large gradient). + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ScaleGrad(0.2), + Balancer(layer1_channels, channel_dim=1, max_abs=1.0), + SwooshR(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + Balancer(layer2_channels, channel_dim=1, max_abs=4.0), + SwooshR(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + Balancer(layer3_channels, channel_dim=1, max_abs=4.0), + SwooshR(), + ) + + # just one convnext layer + self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) + + out_width = (((in_channels - 1) // 2) - 1) // 2 + + self.out = nn.Linear(out_width * layer3_channels, out_channels) + # use a larger than normal grad_scale on this whitening module; there is + # only one such module, so there is not a concern about adding together + # many copies of this extra gradient term. + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), + prob=(0.025, 0.25), + grad_scale=0.02, + ) + + # max_log_eps=0.0 is to prevent both eps and the output of self.out from + # getting large, there is an unnecessary degree of freedom. + self.out_norm = BiasNorm(out_channels) + self.dropout = Dropout3(dropout, shared_dim=1) + + def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). 
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+
+        Returns:
+          - a tensor of shape (N, (T-7)//2, odim)
+          - output lengths, of shape (batch_size,)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
+        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
+        # gradients.
+        x = self.conv(x)
+        x = self.convnext(x)
+
+        # Now x is of shape (N, layer3_channels, (T-7)//2, ((idim-1)//2 - 1)//2)
+        b, c, t, f = x.size()
+
+        x = x.transpose(1, 2).reshape(b, t, c * f)
+        # now x: (N, (T-7)//2, out_width * layer3_channels)
+
+        x = self.out(x)
+        # Now x is of shape (N, (T-7)//2, odim)
+        x = self.out_whiten(x)
+        x = self.out_norm(x)
+        x = self.dropout(x)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            x_lens = (x_lens - 7) // 2
+        assert x.size(1) == x_lens.max().item()
+
+        return x, x_lens
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py b/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py
new file mode 100755
index 000000000..13483637d
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless4/test_model.py +""" + +from scaling import ScheduledFloat +from train_subformer import get_params, get_text_encoder, get_transducer_model +from zipformer import Zipformer2 + + +def test_model_1(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = 24 + params.dim_feedforward = 1536 # 384 * 4 + params.encoder_dim = 384 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf +def test_model_M(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,15,15" + + params.text_encoder_dim = (192, 192, 256, 384) + params.decoder_dim = 512 + params.joiner_dim = 512 + model = Zipformer2( + output_downsampling_factor=8, + downsampling_factor=(1, 2, 4, 8), + num_encoder_layers=(2, 4, 4, 4), + encoder_dim=(192, 192, 256, 384), + encoder_unmasked_dim=(192, 192, 256, 256), + query_head_dim=(32, 32, 32, 32), + pos_head_dim=(4, 4, 4, 4), + value_head_dim=(12, 12, 12, 12), + pos_dim=48, + num_heads=(4, 4, 4, 8), + feedforward_dim=( + 384, + 512, + 768, + 1024, + ), # could increase this if there is nough data + cnn_module_kernel=(31, 31, 15, 15), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=False, + ) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + model = Zipformer2( + output_downsampling_factor=8, + downsampling_factor=(1, 2, 4, 8), + num_encoder_layers=(2, 4, 6, 6), + encoder_dim=(256, 256, 384, 512), + encoder_unmasked_dim=(196, 196, 256, 256), + query_head_dim=(32, 32, 32, 32), + pos_head_dim=(4, 4, 4, 4), + value_head_dim=(12, 12, 12, 12), + pos_dim=48, + num_heads=(4, 4, 4, 8), + feedforward_dim=( + 384, + 512, + 768, + 1024, + ), # could increase this if there is nough data + cnn_module_kernel=(31, 31, 15, 15), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=False, + ) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + # test_model_1() + test_model_M() + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py new file mode 100644 index 000000000..efb4acc3c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py @@ -0,0 +1,101 @@ +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List + + +def train_text_normalization(s: str) -> str: + # replace full-width with half-width + s = s.replace("“", '"') + s = s.replace("”", '"') + s = s.replace("‘", "'") + s = s.replace("’", "'") + if s[:2] == '" ': # remove the starting double quote + s = s[2:] + + return s + + +def ref_text_normalization(ref_text: str) -> str: + # Rule 1: Remove the [FN#[]] + p = r"[FN#[0-9]*]" + pattern = re.compile(p) + + res = pattern.findall(ref_text) + ref_text = re.sub(p, "", ref_text) + + ref_text = train_text_normalization(ref_text) + + return ref_text + + +def remove_non_alphabetic(text: str, strict: bool = True) -> str: + # Recommend to set strict to False + if not strict: + # Note, this also keeps space, single quote(') and hypen (-) + text = text.replace("-", " ") + text = text.replace("—", " ") + return re.sub(r"[^a-zA-Z0-9\s']+", "", text) + else: + # only keeps space + return re.sub(r"[^a-zA-Z\s]+", "", text) + + +def upper_only_alpha(text: str) -> str: + return remove_non_alphabetic(text.upper(), strict=False) + + +def lower_only_alpha(text: str) -> str: + return remove_non_alphabetic(text.lower(), strict=False) + + +def lower_all_char(text: str) -> str: + return text.lower() + + +def upper_all_char(text: str) -> str: + return text.upper() + + +def _apply_style_transform(text: List[str], transform: str) -> List[str]: + """Apply transform to a list of text. By default, the text are in + ground truth format, i.e mixed-punc. + + Args: + text (List[str]): Input text string + transform (str): Transform to be applied + + Returns: + List[str]: _description_ + """ + if transform == "mixed-punc": + return text + elif transform == "upper-no-punc": + return [upper_only_alpha(s) for s in text] + elif transform == "lower-no-punc": + return [lower_only_alpha(s) for s in text] + elif transform == "lower-punc": + return [lower_all_char(s) for s in text] + else: + raise NotImplementedError(f"Unseen transform: {transform}") + + +if __name__ == "__main__": + ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." + print(ref_text) + res = upper_only_alpha(ref_text) + print(res) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py new file mode 100644 index 000000000..7075c9154 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -0,0 +1,1390 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+
+# For mixed precision training:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_prompt_asr/train_baseline.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_prompt_asr/exp \
+  --max-duration 1000
+
+# To train a streaming model
+
+./zipformer_prompt_asr/train_baseline.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --causal 1 \
+  --exp-dir zipformer_prompt_asr/exp \
+  --max-duration 1000
+
+"""
+
+
+import argparse
+import copy
+import logging
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriHeavyAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model_baseline import Transducer
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from text_normalization import train_text_normalization, upper_only_alpha
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    get_parameter_groups_with_lrs,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_first(
+    texts: List[str],
+    pre_texts: List[str],
+    context_list: Optional[str] = None,
+    rare_word_list: Optional[List[str]] = None,
+) -> Dict[str, str]:
+    # Always use the first candidate, which is the mixed-case text with punctuation.
+    out = {"text": texts[0], "pre_text": pre_texts[0]}
+    return out
+
+
+def get_upper_only_alpha(
+    texts: List[str],
+    pre_texts: List[str],
+    context_list: Optional[str] = None,
+    rare_word_list: Optional[List[str]] = None,
+) -> Dict[str, str]:
+    # Always use the first candidate (the mixed-case text with punctuation),
+    # then upper-case it and remove the punctuation.
+    out = {
+        "text": upper_only_alpha(texts[0]),
+        "pre_text": upper_only_alpha(pre_texts[0]),
+    }
+    return out
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+    # returns the number of batches we would have used so far if we had used the reference
+    # duration. This is for purposes of set_batch_count().
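+    # Example (illustrative numbers only): with --max-duration 1000, --world-size 4 and the
+    # default --ref-duration 600, each optimizer step counts as 1000 * 4 / 600 ~= 6.7
+    # "reference" batches.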
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--text-encoder-dim", + type=str, + default="256,256,384,512", + help="Embedding dimension in text encoder stacks: a comma-separated list of 4 elements, " + "or you should change other configs in the code.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from 
file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + texts = [train_text_normalization(t) for t in texts] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + if random.random() < 0.02: + logging.info(f"Ref texts: {texts[0]}") + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. 
+ scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
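+                # The heuristic below: read the current scale from the
+                # GradScaler (via its private _scale attribute); if it is below
+                # 8.0 (or below 32.0, checked every 400 batches), double it.
+                # If it has collapsed below 0.01, save the offending batch once
+                # for debugging and warn; below 1e-05, save it and raise, since
+                # training cannot make progress with such a tiny scale.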
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,  # should have no effect
+        clipping_scale=2.0,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    libriheavy = LibriHeavyAsrDataModule(args)
+
+    train_cuts = libriheavy.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 30 seconds
+        #
+        # Caution: There is a reason to select 30.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 30.0:
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].texts[0], out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].texts[0]}. "
+                f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + text_sampling_func = get_upper_only_alpha + logging.info(f"Text sampling func: {text_sampling_func}") + train_dl = libriheavy.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + text_sampling_func=text_sampling_func, + ) + + valid_cuts = libriheavy.dev_cuts() + valid_dl = libriheavy.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py new file mode 100755 index 000000000..e253d1118 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -0,0 +1,1798 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang, +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For mix precision training: + +(1) Non-streaming model, **without** context list + +./zipformer_prompt_asr/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --subset medium \ + --causal False \ + --exp-dir zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --memory-layer 0 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --use-style-prompt True \ + --use-context-list False + +(2) Non-streaming model, **with** context list + +./zipformer_prompt_asr/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --subset medium \ + --causal False \ + --exp-dir zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --memory-layer 0 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --use-style-prompt True \ + --use-context-list True \ + --top-k 10000 \ + --rare-word-file data/context_biasing/small_rare_words_topk_10000.txt + + +""" + + +import argparse +import copy +import logging +import os +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import numpy +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from dataset import ( + naive_triplet_text_sampling, + random_shuffle_subset, + triplet_text_sampling, + triplet_text_sampling_with_context_list, +) +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_with_BERT import PromptedTransducer +from optim import Eden, ScaledAdam +from scaling import Balancer, BiasNorm, Dropout3, ScaleGrad, ScheduledFloat, SwooshR +from subsampling import Conv2dSubsampling +from text_normalization import ( + lower_all_char, + lower_only_alpha, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + +style_transforms = [ + lambda x: x, # return it self + upper_only_alpha, + lower_only_alpha, + lower_all_char, +] + + +def get_first(texts: List[str], pre_texts: List[str]) -> str: + out = { + "text": texts[0], + "pre_text": pre_texts[0], + "style_text": "", + "transform_ids": 0, + } + return out + + +def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str: + # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha + out = { + "text": upper_only_alpha(texts[0]), + "pre_text": upper_only_alpha(pre_texts[0]), + "style_text": "", + "transform_ids": 0, + } + return out + + +def 
get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--memory-dropout-rate", + type=float, + default=0.05, + help="By which probability, dropout the memory when doing cross-attention.", + ) + + parser.add_argument( + "--memory-layer", + type=int, + default=0, + help="Start doing cross-attention from which layer. Zero-indexed", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--freeze-text-encoder", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--text-encoder-type", + type=str, + default="BERT", + choices=["BERT", "DistilBERT"], + help="Type of the text encoder", + ) + + parser.add_argument( + "--text-encoder-dim", + type=int, + default=768, + help="Dimension of the text encoder", + ) + + parser.add_argument( + "--text-encoder-adapter", + type=str2bool, + default=False, + help="An adapter for pre-trained BERT", + ) + + parser.add_argument( + "--context-injection", + type=str2bool, + default=False, + help="Inject context embedding into the joiner", + ) + + parser.add_argument( + "--context-dropout-rate", + type=float, + default=0.05, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. 
+ """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Whether to use style prompt.", + ) + + # arguments for using prompt + parser.add_argument( + "--pre-text-shuffle-prob", + type=float, + default=0.05, + help="The proportion of pre_text to be shuffled with in a batch", + ) + + parser.add_argument( + "--style-text-shuffle-prob", + type=float, + default=0.2, + help="The proportion of style_text to be shuffled with in a batch", + ) + + parser.add_argument( + "--prompt-mask-prob", + type=float, + default=0.05, + help="The probability of masking prompts", + ) + + parser.add_argument( + "--forced-upper-pre-text", + type=str2bool, + default=False, + help="Forced format of pre-text", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
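+    # For example, assuming the usual 10 ms frame shift (100 feature frames
+    # per second), a 10 s utterance has T = 1000, so encoder_embed outputs
+    # (1000 - 7) // 2 = 496 frames; after the extra 2x downsampling at the
+    # encoder output this becomes roughly 248 frames, i.e. an overall
+    # subsampling factor of about 4 (cf. params.subsampling_factor).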
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +class TextEmbedding(nn.Module): + def __init__( + self, + num_embeddings: int = 256, + embedding_dim: int = 256, + kernel_size: int = 3, + layer1_channels: int = 256, + layer2_channels: int = 256, + bias: bool = True, + dropout: float = 0.1, + ): + super().__init__() + self.embed = nn.Embedding( + num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes + embedding_dim=embedding_dim, # + ) + + assert embedding_dim == layer1_channels # for depth wise convolution + self.conv = nn.Sequential( + nn.Conv1d( + embedding_dim, + layer1_channels, # depthwise convolution + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=layer1_channels, + bias=True, + ), + ScaleGrad(0.2), + Balancer(layer1_channels, channel_dim=1, min_positive=0.1, max_abs=1.0), + nn.ReLU(), + nn.Conv1d( + layer1_channels, + layer2_channels, + kernel_size=1, # pointwise convolution + stride=1, + padding=0, + bias=True, + ), + Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0), + nn.ReLU(), + ) + + self.out_norm = BiasNorm(layer2_channels) + self.dropout = Dropout3(dropout, shared_dim=1) + + def forward(self, text: torch.Tensor) -> torch.Tensor: + """Forward function of the text embedding + + Args: + text (torch.Tensor): Text in UTF-8 bytes (T,N) + Returns: + The embeddings of text (T,N,C) + """ + text = self.embed(text) # (T,N,C) + + # src = text + text = text.permute(1, 2, 0) # (T,N,C) -> (N,C,T) + text = self.conv(text) + text = text.permute(2, 0, 1) # (N,C,T) -> (T,N,C) + # src = src + text + + text = self.out_norm(text) + text = self.dropout(text) + + return text + + +def get_text_encoder(params: AttributeDict) -> nn.Module: + # Return a text encoder + if params.text_encoder_type == "BERT": # This is a BERT-base-cased + from transformers import BertModel + + logging.info("Loading pre-trained BERT-base-cased as text encoder") + if os.path.exists("data/models/bert-base-cased"): + model = BertModel.from_pretrained("data/models/bert-base-cased") + else: + model = BertModel.from_pretrained("bert-base-cased") + assert params.text_encoder_dim == 768 + elif params.text_encoder_type == "BERT-large": + from transformers import BertModel + + logging.info("Loading pre-trained BERT-large-uncased as text encoder") + if os.path.exists("data/models/bert-large-uncased"): + model = BertModel.from_pretrained("data/models/bert-large-uncased") + else: + model = BertModel.from_pretrained("bert-large-uncased") + assert params.text_encoder_dim == 1024 + elif params.text_encoder_type == "DistilBERT": + from transformers import DistilBertModel # This is a DistilBERT-base-cased + + logging.info("Loading pre-trained DistilBERT-base-cased as text encoder") + model = DistilBertModel.from_pretrained("distilbert-base-cased") + assert params.text_encoder_dim == 768 + else: + raise ValueError() + + return model + + +def get_tokenizer(params: AttributeDict): + + if params.text_encoder_type == "BERT": + from transformers import BertTokenizer + + # This is a BERT-base-cased + if os.path.exists("data/models/bert-base-cased"): + tokenizer = BertTokenizer.from_pretrained("data/models/bert-base-cased") + else: + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + elif params.text_encoder_type == "BERT-large": + from transformers import BertTokenizer + + # This is a BERT-large-uncased + if 
os.path.exists("data/models/bert-large-uncased"): + tokenizer = BertTokenizer.from_pretrained("data/models/bert-large-uncased") + else: + tokenizer = BertTokenizer.from_pretrained("bert-large-uncased") + elif params.text_encoder_type == "DistilBERT": + from transformers import DistilBertTokenizer + + tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased") + else: + raise ValueError() + + return tokenizer + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + memory_dim=params.text_encoder_dim, # This is fixed as the BERT base model is 768-D + memory_layer=params.memory_layer, + memory_dropout_rate=params.memory_dropout_rate, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + context_dim=4 * 768 + if params.context_injection + else -1, # the output dim of text encoder + context_injection=params.context_injection, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + text_encoder = get_text_encoder(params) # This should be a cased BERT base model + num_param = sum([p.numel() for p in text_encoder.parameters()]) + logging.info(f"Num params in text encoder: {num_param}") + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = PromptedTransducer( + encoder_embed=encoder_embed, + encoder=encoder, + text_encoder=text_encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + text_encoder_type=params.text_encoder_type, + text_encoder_adapter=params.text_encoder_adapter, + freeze_text_encoder=params.freeze_text_encoder, + context_fuser=None, + ) + + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def _encode_texts_as_bytes_with_tokenizer( + pre_texts: List[str], + style_texts: List[str], + tokenizer, + device: torch.device, + max_len: int = 500, + no_limit: bool = False, +) -> Tuple[Dict, Tensor]: + """ + Encode texts as bytes and then integer tensors. + Note that the style text will be added to the beginning of texts. 
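+
+    The combined prompt is built as "style_text [SEP] pre_text", with the
+    pre_text truncated from the left so the character budget is respected,
+    and the whole string is then tokenized with the given tokenizer.
+
+    Args:
+      pre_texts: previous-context transcripts, one per utterance.
+      style_texts: style prompts that are prepended to the pre_texts.
+      tokenizer: the HuggingFace tokenizer matching the text encoder.
+      device: the device onto which the encoded tensors are moved.
+      max_len: maximum number of tokens kept by the tokenizer (capped at 500).
+      no_limit: if True, allow a much larger character budget for the pre_texts.
+    Returns:
+      A tuple (encoded_inputs, style_lens): encoded_inputs is the tokenizer
+      output for the combined prompts, and style_lens holds the token length
+      of the style part of each prompt.
+
+    Example (illustrative only; assumes `tokenizer` was created by
+    :func:`get_tokenizer`):
+
+      >>> encoded, style_lens = _encode_texts_as_bytes_with_tokenizer(
+      ...     pre_texts=["previous chunk transcript"],
+      ...     style_texts=["Mixed-case, punctuated."],
+      ...     tokenizer=tokenizer,
+      ...     device=torch.device("cpu"),
+      ... )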
+ """ + batch_size = len(pre_texts) + max_len = min(max_len, 500) + + if no_limit: + allowed_lens = [5000 - len(s) for s in style_texts] + else: + allowed_lens = [1000 - len(s) for s in style_texts] + truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(batch_size)] + combined_text = [ + style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(batch_size) + ] + + encoded_style_texts = tokenizer( + style_texts, + return_tensors="pt", + padding=True, + truncation=True, + return_length=True, + max_length=max_len, + ) + style_lens = encoded_style_texts["length"].to(device) + + # Use tokenizer to prepare input for text encoder + encoded_inputs = tokenizer( + combined_text, + return_tensors="pt", + padding=True, + truncation=True, + return_length=True, + max_length=max_len, + ).to(device) + + return encoded_inputs, style_lens + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + batch_size = feature.size(0) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + pre_texts = batch["supervisions"]["pre_text"] + style_texts = batch["supervisions"][ + "style_text" + ] # the style texts are in gt format + transform_ids = batch["supervisions"]["transform_ids"] + + # This is to replace full-width symbols with half-width symbols + texts = [train_text_normalization(t) for t in texts] + pre_texts = [train_text_normalization(t) for t in pre_texts] + style_texts = [train_text_normalization(t) for t in style_texts] + + y = sp.encode( + texts, out_type=int + ) # sp.encode treats consecutive space as a single space + y = k2.RaggedTensor(y).to(device) + + if params.forced_upper_pre_text: + pre_texts = [upper_only_alpha(p) for p in pre_texts] + + # only shuffle the pre_text and style texts if during training, and use style prompt + if is_training: + # randomly shuffle&mask the pre_text + pre_texts = random_shuffle_subset( + pre_texts, + p=params.pre_text_shuffle_prob, + p_mask=params.prompt_mask_prob, + ) + + if params.use_style_prompt: + if random.random() < 0.5: + # randomly shuffle the style_text + # now the style_texts are all in gt format + style_texts = random_shuffle_subset( + style_texts, + p=params.style_text_shuffle_prob, + p_mask=params.prompt_mask_prob, + ) + + assert len(transform_ids) == len(style_texts) + + for i in range(len(style_texts)): + t = transform_ids[i] # get the transform id + style_texts[i] = style_transforms[t](style_texts[i]) + + if not 
params.use_style_prompt: + style_texts = [ + "" for _ in style_texts + ] # use empty string for style texts if don't use style prompt + + if random.random() < 0.05: + logging.info(f"Pre texts: {pre_texts[0]}") + logging.info(f"Ref texts: {texts[0]}") + logging.info(f"Style texts: {style_texts[0]}") + + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_texts, + style_texts=style_texts, + tokenizer=tokenizer, + device=device, + ) + + if random.random() < 0.02: + logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ") + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + encoded_inputs=encoded_inputs, + style_lens=style_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. 
+ scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
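+                # Instead we double the scale manually whenever it is small (below
+                # 8.0, or below 32.0 every 400 batches) so that training recovers
+                # quickly after an overflow has shrunk it; a scale that keeps
+                # collapsing usually indicates inf/nan gradients, so the offending
+                # model is saved and training is aborted below.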
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if not params.use_style_prompt: + if params.pre_text_shuffle_prob == 0.0: + logging.info( + f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}" + ) + logging.info( + "If style prompt is not used, you should be careful when shuffling the pre_text within the same batch" + ) + logging.info("Hard set this probability to 0.0!") + params.pre_text_shuffle_prob = 0.0 + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + tokenizer = get_tokenizer(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.freeze_text_encoder: + freeze_modules = ["text_encoder"] + logging.info( + "Freeze the parameters of text encoder and don't include them in the optimizer" + ) + else: + freeze_modules = [] + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules + ), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + args.max_duration = 100 + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + libriheavy = LibriHeavyAsrDataModule(args) + + train_cuts = libriheavy.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 30.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].texts[0], out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].texts[0]}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + if params.use_context_list: + text_sampling_func = triplet_text_sampling_with_context_list + else: + text_sampling_func = triplet_text_sampling + + logging.info(f"Text sampling: {text_sampling_func}") + + train_dl = libriheavy.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + text_sampling_func=text_sampling_func, + ) + + # For fair comparison, use fixed sampling in valid dataloaders + valid_cuts = libriheavy.dev_cuts() + valid_dl = libriheavy.valid_dataloaders( + valid_cuts, text_sampling_func=naive_triplet_text_sampling + ) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + tokenizer=tokenizer, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + tokenizer=tokenizer, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. 
+ sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + tokenizer: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py new file mode 100644 index 000000000..ef0c48e8a --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py @@ -0,0 +1,515 @@ +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +python ./zipformer_prompt_asr/transcribe_bert.py \ + --epoch 50 \ + --avg 10 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --manifest-dir data/long_audios/long_audio.jsonl.gz \ + --pre-text-transform mixed-punc \ + --style-text-transform mixed-punc \ + --num-history 5 \ + --use-pre-text True \ + --use-gt-pre-text False + + +""" + +import argparse +import logging +import math +import warnings +from pathlib import Path +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from decode_bert import _apply_style_transform +from lhotse import Fbank, load_manifest +from text_normalization import ( + lower_all_char, + lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from tqdm import tqdm +from train_bert_encoder import ( + _encode_texts_as_bytes_with_tokenizer, + add_model_arguments, + get_params, + get_tokenizer, + get_transducer_model, +) + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/long_audios/long_audio.jsonl.gz", + help="""This is the manfiest for long audio transcription. 
+ The cust are intended to be sorted, i.e first sort by recording ID and + then sort by start timestamp""", + ) + + parser.add_argument( + "--use-pre-text", + type=str2bool, + default=False, + help="Whether use pre-text when decoding the current chunk", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Use style prompt when evaluation", + ) + + parser.add_argument( + "--pre-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of content prompt, i.e pre_text", + ) + + parser.add_argument( + "--style-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of style prompt, i.e style_text", + ) + + parser.add_argument( + "--num-history", + type=int, + default=2, + help="How many previous chunks to look if using pre-text for decoding", + ) + + parser.add_argument( + "--use-gt-pre-text", + type=str2bool, + default=False, + help="Whether use gt pre text when using content prompt", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=True, + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + params.res_dir = params.exp_dir / "long_audio_transcribe" + params.res_dir.mkdir(exist_ok=True) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "beam_search" in params.method: + params.suffix += f"-{params.method}-beam-size-{params.beam_size}" + + if params.use_pre_text: + if params.use_gt_pre_text: + params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}" + else: + params.suffix += ( + f"-pre-text-{params.pre_text_transform}-history-{params.num_history}" + ) + + book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "") + setup_logger( + f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info" + ) + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + tokenizer = get_tokenizer(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( 
+ filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + # load manifest + manifest = load_manifest(params.manifest_dir) + + results = [] + count = 0 + + last_recording = "" + last_end = -1 + history = [] + num_pre_texts = [] + + for cut in manifest: + if cut.has_features: + feat = cut.load_features() + feat_lens = cut.num_frames + else: + feat = cut.compute_features(extractor=Fbank()) + feat_lens = feat.shape[0] + + cur_recording = cut.recording.id + + if cur_recording != last_recording: + last_recording = cur_recording + history = [] # clean up the history + last_end = -1 + logging.info("Moving on to the next recording") + else: + if cut.start < last_end - 0.2: # overlap with the previous cuts + logging.warning("An overlap exists between current cut and last cut") + logging.warning("Skipping this cut!") + continue + if cut.start > last_end + 10: + logging.warning( + f"Large time gap between the current and previous utterance: {cut.start - last_end}." + ) + + # prepare input + x = torch.tensor(feat, device=device).unsqueeze(0) + x_lens = torch.tensor( + [ + feat_lens, + ], + device=device, + ) + + if params.use_pre_text: + if params.num_history > 0: + pre_texts = history[-params.num_history :] + else: + pre_texts = [] + num_pre_texts.append(len(pre_texts)) + pre_texts = [train_text_normalization(" ".join(pre_texts))] + fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." 
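+            # There is no ground-truth style text when transcribing long audio, so a
+            # fixed sentence written in the target style (mixed case, punctuated) is
+            # used as the style prompt; its wording is deliberately unrelated to the
+            # audio, and only its casing and punctuation style matter.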
+ style_texts = [fixed_sentence] + + pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) + if params.use_style_prompt: + style_texts = _apply_style_transform( + style_texts, params.style_text_transform + ) + + # encode prompts + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_texts, + style_texts=style_texts, + tokenizer=tokenizer, + device=device, + no_limit=True, + ) + if params.num_history > 5: + logging.info( + f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} " + ) + + memory, memory_key_padding_mask = model.encode_text( + encoded_inputs=encoded_inputs, + style_lens=style_lens, + ) # (T,B,C) + else: + memory = None + memory_key_padding_mask = None + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=x, + feature_lens=x_lens, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + if params.method == "greedy_search": + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + hyp = sp.decode(hyp_tokens)[0] # in string format + ref_text = ref_text_normalization( + cut.supervisions[0].texts[0] + ) # required to match the training + + # extend the history + if params.use_gt_pre_text: + history.append(ref_text) + else: + history.append(hyp) + last_end = cut.end # update the last end timestamp + + # append the current decoding result + hyp = hyp.split() + ref = ref_text.split() + results.append((cut.id, ref, hyp)) + + count += 1 + if count % 100 == 0: + logging.info(f"Cuts processed until now: {count}/{len(manifest)}") + logging.info( + f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}" + ) + + logging.info(f"A total of {count} cuts") + logging.info( + f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}" + ) + + results = sorted(results) + recog_path = ( + params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + errs_filename = ( + params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"long-audio-{params.method}", + results, + enable_log=True, + compute_CER=False, + ) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + if params.post_normalization: + params.suffix += "-post-normalization" + + new_res = [] + for item in results: + id, ref, hyp = item + hyp = upper_only_alpha(" ".join(hyp)).split() + ref = upper_only_alpha(" ".join(ref)).split() + new_res.append((id, ref, hyp)) + + new_res = sorted(new_res) + recog_path = ( + params.res_dir + / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt" + ) + store_transcripts(filename=recog_path, texts=new_res) + logging.info(f"The transcripts are stored in {recog_path}") + + errs_filename = ( + params.res_dir + / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"long-audio-{params.method}", + new_res, + 
enable_log=True, + compute_CER=False, + ) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py new file mode 100644 index 000000000..533982519 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py @@ -0,0 +1,439 @@ +import argparse +import ast +import glob +import logging +import os +from collections import defaultdict +from typing import Dict, Iterable, List, TextIO, Tuple, Union + +import kaldialign +from lhotse import load_manifest, load_manifest_lazy +from lhotse.cut import Cut, CutSet +from text_normalization import remove_non_alphabetic +from tqdm import tqdm + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/fbank", + help="Where are the manifest stored", + ) + + parser.add_argument( + "--subset", type=str, default="medium", help="Which subset to work with" + ) + + parser.add_argument( + "--top-k", + type=int, + default=10000, + help="How many words to keep", + ) + + return parser + + +def get_facebook_biasing_list( + test_set: str, + num_distractors: int = 100, +) -> Dict: + # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf + assert num_distractors in (0, 100, 500, 1000, 2000), num_distractors + if num_distractors == 0: + if test_set == "test-clean": + biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv" + elif test_set == "test-other": + biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv" + else: + raise ValueError(f"Unseen test set {test_set}") + else: + if test_set == "test-clean": + biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv" + elif test_set == "test-other": + biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv" + else: + raise ValueError(f"Unseen test set {test_set}") + + f = open(biasing_file, "r") + data = f.readlines() + f.close() + + output = dict() + for line in data: + id, _, l1, l2 = line.split("\t") + if num_distractors > 0: # use distractors + biasing_list = ast.literal_eval(l2) + else: + biasing_list = ast.literal_eval(l1) + biasing_list = [w.strip().upper() for w in biasing_list] + output[id] = " ".join(biasing_list) + + return output + + +def brian_biasing_list(level: str): + # The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf + root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level" + all_files = glob.glob(root_dir + "/*") + biasing_dict = {} + for f in all_files: + k = f.split("/")[-1] + fin = open(f, "r") + data = fin.read().strip().split() + biasing_dict[k] = " ".join(data) + fin.close() + + return biasing_dict + + +def get_rare_words( + subset: str = "medium", + top_k: int = 10000, + # min_count: int = 10000, +): + """Get a list of rare words appearing less than `min_count` times + + Args: + subset: The dataset + top_k (int): How many frequent words + """ + txt_path = f"data/tmp/transcript_words_{subset}.txt" + rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt" + + if os.path.exists(rare_word_file): + print("File exists, do not proceed!") + return + + print("---Identifying rare words in the manifest---") + count_file = 
f"data/tmp/transcript_words_{subset}_count.txt" + if not os.path.exists(count_file): + with open(txt_path, "r") as file: + words = file.read().upper().split() + word_count = {} + for word in words: + word = remove_non_alphabetic(word, strict=False) + word = word.split() + for w in word: + if w not in word_count: + word_count[w] = 1 + else: + word_count[w] += 1 + + word_count = list(word_count.items()) # convert to a list of tuple + word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True) + with open(count_file, "w") as fout: + for w, count in word_count: + fout.write(f"{w}\t{count}\n") + + else: + word_count = {} + with open(count_file, "r") as fin: + word_count = fin.read().strip().split("\n") + word_count = [pair.split("\t") for pair in word_count] + word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True) + + print(f"A total of {len(word_count)} words appeared!") + rare_words = [] + for word, count in word_count[top_k:]: + rare_words.append(word + "\n") + print(f"A total of {len(rare_words)} are identified as rare words.") + + with open(rare_word_file, "w") as f: + f.writelines(rare_words) + + +def add_context_list_to_manifest( + manifest_dir: str, + subset: str = "medium", + top_k: int = 10000, +): + """Generate a context list of rare words for each utterance in the manifest + + Args: + manifest_dir: Where to store the manifest with context list + subset (str): Subset + top_k (int): How many frequent words + + """ + orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz" + target_manifest_dir = orig_manifest_dir.replace( + ".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz" + ) + if os.path.exists(target_manifest_dir): + print(f"Target file exits at {target_manifest_dir}!") + return + + rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt" + print(f"---Reading rare words from {rare_words_file}---") + with open(rare_words_file, "r") as f: + rare_words = f.read() + rare_words = rare_words.split("\n") + rare_words = set(rare_words) + print(f"A total of {len(rare_words)} rare words!") + + cuts = load_manifest_lazy(orig_manifest_dir) + print(f"Loaded manifest from {orig_manifest_dir}") + + def _add_context(c: Cut): + splits = ( + remove_non_alphabetic(c.supervisions[0].texts[0], strict=False) + .upper() + .split() + ) + found = [] + for w in splits: + if w in rare_words: + found.append(w) + c.supervisions[0].context_list = " ".join(found) + return c + + cuts = cuts.map(_add_context) + print(f"---Saving manifest with context list to {target_manifest_dir}---") + cuts.to_file(target_manifest_dir) + print("Finished") + + +def check( + manifest_dir: str, + subset: str = "medium", + top_k: int = 10000, +): + # Show how many samples in the training set have a context list + # and the average length of context list + print("--- Calculating the stats over the manifest ---") + + manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz" + cuts = load_manifest_lazy(manifest_dir) + total_cuts = len(cuts) + has_context_list = [c.supervisions[0].context_list != "" for c in cuts] + context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts] + print(f"{sum(has_context_list)}/{total_cuts} cuts have context list! 
") + print( + f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}" + ) + + +def write_error_stats( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, + compute_CER: bool = False, + biasing_words: List[str] = None, +) -> float: + """Write statistics based on predicted results and reference transcripts. It also calculates the + biasing word error rate as described in https://arxiv.org/pdf/2104.02194.pdf + + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the cut_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + biasing_words: + All the words in the biasing list + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + for ref_word, hyp_word in ali: + if ref_word == ERR: # INSERTION + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: # DELETION + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: # SUBSTITUTION + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for 
i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [ + [ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] + for x, y in ali + ] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [ + [ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] + for x, y in ali + ] + + print( + f"{cut_id}:\t" + + " ".join( + ( + ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + unbiased_word_counts = 0 + unbiased_word_errs = 0 + biased_word_counts = 0 + biased_word_errs = 0 + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + # number of appearances of "word" in reference text + ref_count = ( + corr + ref_sub + dels + ) # correct + in ref but got substituted + deleted + # number of appearances of "word" in hyp text + hyp_count = corr + hyp_sub + ins + + if biasing_words is not None: + if word in biasing_words: + biased_word_counts += ref_count + biased_word_errs += ins + dels + ref_sub + else: + unbiased_word_counts += ref_count + unbiased_word_errs += ins + dels + hyp_sub + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + + if biasing_words is not None: + B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts) + U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts) + logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}] ") + logging.info( + f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]" + ) + + return float(tot_err_rate) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + manifest_dir = args.manifest_dir + subset = args.subset + top_k = args.top_k + get_rare_words(subset=subset, top_k=top_k) + add_context_list_to_manifest( + manifest_dir=manifest_dir, + subset=subset, + top_k=top_k, + ) + check( + manifest_dir=manifest_dir, + subset=subset, + top_k=top_k, + ) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py new file mode 100644 index 000000000..d1cf90ffb --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py @@ -0,0 +1,2310 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, + Balancer, + BiasNorm, + ChunkCausalDepthwiseConv1d, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + value_head_dim (int or Tuple[int]): dimension of value in each attention head + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. 
Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + memory_dim: if supplied and >0, will be the dimension of the memory embeddings + passed into the zipformer (e.g. this might be the output of another + Zipformer used to create embedding vectors.) + memory_dropout_rate: By this probability, do not use the provided memory for + cross-attention. This should give robustness to the model when evaluated + without memory. + memory_layer: if supplied and >0, only add cross-attention module starting from + the specified layer. + """ + + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + memory_dim: int = -1, + memory_dropout_rate: float = 0.05, + memory_layer: int = 0, + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + self.memory_dropout_rate = memory_dropout_rate + self.memory_layer = memory_layer + + for u, d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + memory_dim=memory_dim if i >= self.memory_layer else -1, + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn 
something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. 
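+
+        Illustrative example (hypothetical settings, assuming causal=True): if the
+        model was built with chunk_size=[16, 32, -1] and left_context_frames=[64, 128],
+        then during training one chunk size is chosen at random per forward pass,
+        say 32, and left_context_chunks becomes 64 // 32 = 2 or 128 // 32 = 4;
+        a chunk size of -1 disables chunking (and -1 is returned for both values).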
+ """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. + + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + if self.training and memory is not None: + batch_size = x.shape[1] + # setting memory to zero should be equivalent to not using the + # memory input at all, since the Attention module has no biases. + memory = memory * ( + torch.rand(batch_size, 1, device=memory.device) + > self.memory_dropout_rate + ) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + memory=memory if i >= self.memory_layer else None, + memory_key_padding_mask=memory_key_padding_mask + if i >= self.memory_layer + else None, + ) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. 
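+        # (e.g. with output_downsampling_factor == 2, inputs of 9 or 10 frames both
+        # come out as 5 frames, which is what lengths = (x_lens + 1) // 2 computes.)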
+ assert self.output_downsampling_factor == 2, self.output_downsampling_factor + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. 
+ - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. 
+ + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + memory_dim: int = -1, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. 
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = Attention(embed_dim, embed_dim, num_heads, value_head_dim) + + self.self_attn2 = Attention(embed_dim, embed_dim, num_heads, value_head_dim) + + if memory_dim > 0: + self.attn_weights = MultiheadAttentionWeights( + memory_dim, + embed_dim, + num_heads=num_heads, + head_dim=query_head_dim, + dropout=0.0, + ) + self.src_attn1 = Attention(memory_dim, embed_dim, num_heads, value_head_dim) + self.src_attn2 = Attention(memory_dim, embed_dim, num_heads, value_head_dim) + self.memory_balancer = Balancer( + embed_dim, + channel_dim=-1, + min_abs=0.015, + ) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + # self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2) + + self.norm = BiasNorm(embed_dim) + + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. 
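+        # Concretely, a scale of 1.0 keeps the module's full output and a scale of
+        # 0.0 reduces it to the identity mapping; see BypassModule.forward below for
+        # the actual combination src_orig + (src - src_orig) * bypass_scale.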
+ if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, + min=float(self.bypass_min), + max=float(self.bypass_max), + ) + layer_skip_rate = float(self.layer_skip_rate) + if layer_skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + return ans + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting(): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + if memory is not None and hasattr(self, "attn_weights"): + src_attn_weights = self.attn_weights(memory, src, memory_key_padding_mask) + + src = src + self.feed_forward1(src) + + attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) + + if True: + selected_attn_weights = attn_weights[0:2] + if random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. 
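+                # For example, a weight row of [0.5, 0.3, 0.2, 0.0] (last position
+                # masked out) becomes [1/3, 1/3, 1/3, 0.0]: every unmasked source
+                # position gets equal weight.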
+ # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights[0:1])) + + src = src + (na if attn_dropout_mask is None else na * attn_dropout_mask) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask + ) + + if memory is not None and hasattr(self, "attn_weights"): + src = src + self.sequence_dropout( + self.memory_balancer(self.src_attn1(memory, src_attn_weights)), + attention_skip_rate, + ) + + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + float(self.conv_skip_rate), + ) + + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), float(self.ff2_skip_rate) + ) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask + ) + + if memory is not None and hasattr(self, "attn_weights"): + src = src + self.sequence_dropout( + self.memory_balancer(self.src_attn2(memory, src_attn_weights)), + attention_skip_rate, + ) + + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + float(self.conv_skip_rate), + ) + + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), float(self.ff3_skip_rate) + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). 
+ pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + output = output * feature_mask + + return output + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. 
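+
+    Its forward() computes src_orig + bypass_scale * (src - src_orig), with the learnable
+    per-channel bypass_scale constrained to [scale_min, scale_max] during training.
+
+    A minimal usage sketch (illustrative shapes only)::
+
+        >>> bypass = BypassModule(embed_dim=256)
+        >>> src_orig = torch.rand(100, 4, 256)   # (seq_len, batch_size, num_channels)
+        >>> src = torch.rand(100, 4, 256)        # e.g. output of some sub-module
+        >>> out = bypass(src_orig, src)          # same shape as the inputs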
+ """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, downsample, dropout) + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0.025) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. 
+ memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + if seq_len != d_seq_len * ds: + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. 
+
+    Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
+    using the atan() function, before doing the Fourier transform of that fixed interval.  On its
+    own, the atan() function would compress the "long tails" too much,
+    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+    function to compress large offsets to a smaller range before applying atan().
+    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).
+
+
+    Args:
+        embed_dim: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length: just a heuristic for initialization.
+        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+           less weight to small differences of offset near the origin.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        dropout_rate: FloatLike,
+        max_len: int = 1000,
+        length_factor: float = 1.0,
+    ) -> None:
+        """Construct a CompactRelPositionalEncoding object."""
+        super(CompactRelPositionalEncoding, self).__init__()
+        self.embed_dim = embed_dim
+        assert embed_dim % 2 == 0
+        self.dropout = Dropout2(dropout_rate)
+        self.pe = None
+        assert length_factor >= 1.0
+        self.length_factor = length_factor
+        self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(0) >= x.size(0) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+
+        T = x.size(0)
+        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
+        x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+        # `compression_length` is arbitrary/heuristic: if it is larger we have more resolution
+        # for small time offsets but less resolution for large time offsets.
+        compression_length = self.embed_dim**0.5
+        # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity;
+        # but it does so more slowly than x for large absolute values of x.
+        # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
+        # is important.
+        x_compressed = (
+            compression_length
+            * x.sign()
+            * ((x.abs() + compression_length).log() - math.log(compression_length))
+        )
+
+        # if self.length_factor == 1.0, then length_scale is chosen so that the
+        # FFT can exactly separate points close to the origin (offset == 0).  So this
+        # part of the formulation is not really heuristic.
+        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+
+        # note for machine implementations: if atan is not available, we can use:
+        #  x.sign() * ((1 / (x.abs() + 1)) - 1)  * (-math.pi/2)
+        #  check on wolframalpha.com: plot(sign(x) *  (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
+        x_atan = (x_compressed / length_scale).atan()  # results between -pi/2 and pi/2
+
+        cosines = (x_atan * freqs).cos()
+        sines = (x_atan * freqs).sin()
+
+        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+        pe[:, 0::2] = cosines
+        pe[:, 1::2] = sines
+        pe[:, -1] = 1.0  # for bias.
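+        # pe now has one row per relative offset in [-(T-1), T-1], i.e. shape
+        # (2*T - 1, embed_dim), with cosines in the even columns and sines in the
+        # odd columns (the last column is overwritten with 1.0 to act as a bias).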
+ + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> Tensor: + """Create positional encoding. + + Args: + x (torch.Tensor): Input tensor (time, batch, `*`). + + Returns: + positional embedding, of shape (1, 2*time-1, `*`). + + """ + self.extend_pe(x) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x.size(0) + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. 
+ self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim) + chunk_size + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + if not self.training or random.random() >= float(self.pos_emb_skip_rate): + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. 
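+            # Concretely (up to the sign convention noted above), entry [t, s] of the
+            # last two axes ends up holding the score for relative offset s - t,
+            # i.e. source (key) position minus target (query) position.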
+ pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if self.training and random.random() < 0.1: + # This is away of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 25.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class Attention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. 
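+    Roughly speaking, it computes out_proj(matmul(attn_weights, in_proj(x))) with the
+    usual per-head reshaping, i.e. only the value and output projections of a standard
+    multi-head attention layer.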
+ + Args: + embed_dim_in: the input embedding dimension + embed_dim_out: the output embedding dimension (normally the same as input) + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim_in: int, + embed_dim_out: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim_in, num_heads * value_head_dim, bias=False) + + # Note we set bias to False so that input of 0 will have no effect + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim_out, bias=False, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, query_len, key_len), + Expect attn_weights.sum(dim=-1) == 1. The input here is the value in the + original attention mechanism. + Returns: + a tensor with the same shape as x. + """ + (num_heads, batch_size, query_len, key_len) = attn_weights.shape + + x = self.in_proj(x) # (key_len, batch_size, num_heads * value_head_dim) + x = x.reshape(key_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, key_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, query_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(query_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (query_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] 
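+        # cached_val now holds the most recent left_context_len frames of the
+        # projected values, ready to be passed back in for the next chunk.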
+ + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class MultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head cross-attention weights. Allows src and target + to have different dims. + + Args: + key_embed_dim: number of channels of the thing that we'll project to + make the query (corresponds to source). e.g. 256 + query_embed_dim: number of channels of the thing that we'll project to + make the query (corresponds to target). e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + head_dim: dimension of the query and key, per head. e.g. 24. + dropout: dropout probability for attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + key_embed_dim: int, + query_embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.key_embed_dim = key_embed_dim + self.query_embed_dim = query_embed_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.query_in_proj = ScaledLinear( + query_embed_dim, + head_dim * num_heads, + bias=True, + initial_scale=head_dim**-0.25, + ) + + # weights produced by this module are invariant to adding a constant to + # the keys, so we don't need a bias for the keys. + self.key_in_proj = ScaledLinear( + key_embed_dim, + head_dim * num_heads, + bias=False, + initial_scale=head_dim**-0.25, + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + def forward( + self, + key: Tensor, + query: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + key: input of shape (key_len, batch_size, key_embed_dim) + query: input of shape (query_len, batch_size, query_embed_dim) + key_padding_mask: an optional bool tensor of shape (batch_size, key_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, query_len, key_len) + """ + q = self.query_in_proj(query) + k = self.key_in_proj(key) + + head_dim = self.head_dim + num_heads = self.num_heads + + query_len, batch_size, _ = q.shape + key_len, _batch_size, _ = k.shape + assert _batch_size == batch_size + + k = self.whiten_keys(k) # does nothing in the forward pass. + + q = q.reshape(query_len, batch_size, num_heads, head_dim) + k = k.reshape(key_len, batch_size, num_heads, head_dim) + + # time1 refers to target, time2 refers to source. 
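+        # so attn_scores[h, b, t1, t2] below scores how strongly target (query)
+        # frame t1 attends to source (key) frame t2.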
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + if self.training and random.random() < 0.1: + # This is a way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 25.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, query_len, key_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + key_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. 
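+        # so the whole module computes, in order:
+        # Linear -> Balancer -> SwooshL -> Dropout -> Linear -> Whiten.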
+ x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. 
+ Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in + # the range 1 to 4, but sometimes, for some reason, for layer 0 the + # rms ends up being very large, between 50 and 100 for different channels. + # This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. 
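+        # (Recall that nn.functional.glu splits its input in half along the
+        # given dim and returns first_half * sigmoid(second_half); here the
+        # same computation is written out explicitly in forward() as
+        # "x, s = x.chunk(2, dim=-1)" followed by "x = x * self.sigmoid(s)",
+        # with balancer1 below keeping the pre-sigmoid half 's' in a
+        # reasonable range.)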
+ self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=-1) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). 
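+        # Note: compared to forward(), the balancer and whiten modules are
+        # skipped in this streaming path; they only place constraints on
+        # gradients during training and pass activations through unchanged,
+        # so omitting them at inference should not change the output.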
+ + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + memory_dim = 100 + + c = Zipformer2( + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + memory_dim=memory_dim, + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + memory=torch.randn(101, batch_size, memory_dim), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) diff --git a/icefall/utils.py b/icefall/utils.py index 8fda3a4ca..410340d9d 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -483,7 +483,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: def store_transcripts( - filename: Pathlike, texts: Iterable[Tuple[str, str, str]] + filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False ) -> None: """Save predicted results and reference transcripts to a file. @@ -500,6 +500,9 @@ def store_transcripts( """ with open(filename, "w") as f: for cut_id, ref, hyp in texts: + if char_level: + ref = list("".join(ref)) + hyp = list("".join(hyp)) print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) @@ -557,6 +560,7 @@ def write_error_stats( test_set_name: str, results: List[Tuple[str, str]], enable_log: bool = True, + compute_CER: bool = False, sclite_mode: bool = False, ) -> float: """Write statistics based on predicted results and reference transcripts. @@ -585,7 +589,7 @@ def write_error_stats( The reference word `SIR` is missing in the predicted results (a deletion error). results: - An iterable of tuples. The first element is the cur_id, the second is + An iterable of tuples. The first element is the cut_id, the second is the reference transcript and the third element is the predicted result. enable_log: If True, also print detailed WER to the console. 
@@ -602,6 +606,14 @@ def write_error_stats( words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) num_corr = 0 ERR = "*" + + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + for cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) for ref_word, hyp_word in ali: @@ -1426,7 +1438,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa def get_parameter_groups_with_lrs( - model: nn.Module, lr: float, include_names: bool = False + model: nn.Module, + lr: float, + include_names: bool = False, + freeze_modules: List[str] = [], ) -> List[dict]: """ This is for use with the ScaledAdam optimizers (more recent versions that accept lists of @@ -1450,6 +1465,8 @@ def get_parameter_groups_with_lrs( ... ] """ + named_modules = list(model.named_modules()) + # flat_lr_scale just contains the lr_scale explicitly specified # for each prefix of the name, e.g. 'encoder.layers.3', these need # to be multiplied for all prefix of the name of any given parameter. @@ -1469,6 +1486,15 @@ def get_parameter_groups_with_lrs( split_name = name.split(".") # caution: as a special case, if the name is '', split_name will be [ '' ]. prefix = split_name[0] + if prefix == "module": # DDP + module_name = split_name[1] + if module_name in freeze_modules: + logging.info(f"Remove {name} from parameters") + continue + else: + if prefix in freeze_modules: + logging.info(f"Remove {name} from parameters") + continue cur_lr = lr * flat_lr_scale[prefix] if prefix != "": cur_lr *= flat_lr_scale[""] From 2b3c5d799f3a585dc22071a9148424ff77aefd47 Mon Sep 17 00:00:00 2001 From: Wen Ding Date: Wed, 11 Oct 2023 16:58:00 +0800 Subject: [PATCH 009/216] Fix padding issues (#1303) --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bcd419fb7..ab46e233b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -158,7 +158,7 @@ class Conformer(EncoderInterface): if not is_jit_tracing(): assert x.size(0) == lengths.max().item() - src_key_padding_mask = make_pad_mask(lengths) + src_key_padding_mask = make_pad_mask(lengths, x.size(0)) if self.dynamic_chunk_training: assert ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 5b75b8d35..cbde2a2e4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -281,7 +281,7 @@ class Zipformer(EncoderInterface): lengths = (x_lens - 7) >> 1 assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) + mask = make_pad_mask(lengths, x.size(0)) outputs = [] feature_masks = self.get_feature_masks(x) From 855492156a3c84bca67870d808d033fe963f16bf Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 12 Oct 2023 16:48:23 +0800 Subject: [PATCH 010/216] Update finetune.py (#1304) --- egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index 82bc882bd..c943a84af 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -734,7 +734,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: From 162ceaf4b3110d452b5fed337d721c046d7787fa Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 12 Oct 2023 17:05:41 +0800 Subject: [PATCH 011/216] fixes for data preparation (#1307) Issue: #1306 --- egs/aishell/ASR/prepare.sh | 11 +++++++---- egs/librispeech/ASR/prepare.sh | 14 ++++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 9de060e73..d5dbe5726 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -204,10 +204,6 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py --lang-dir $lang_char_dir fi - - if [ ! -f $lang_char_dir/HLG.fst ]; then - ./local/prepare_lang_fst.py --lang-dir $lang_phone_dir --ngram-G ./data/lm/G_3_gram.fst.txt - fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then @@ -262,6 +258,13 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then --max-order=3 \ data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt fi + + if [ ! -f $lang_char_dir/HLG.fst ]; then + lang_phone_dir=data/lang_phone + ./local/prepare_lang_fst.py \ + --lang-dir $lang_phone_dir \ + --ngram-G ./data/lm/G_3_gram.fst.txt + fi fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 93d010ea8..739608572 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -242,10 +242,6 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then $lang_dir/L_disambig.pt \ $lang_dir/L_disambig.fst fi - - if [ ! -f $lang_dir/HL.fst ]; then - ./local/prepare_lang_fst.py --lang-dir $lang_dir --ngram-G ./data/lm/G_3_gram.fst.txt - fi done fi @@ -303,6 +299,16 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then --max-order=4 \ $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_dir \ + --ngram-G ./data/lm/G_3_gram.fst.txt + fi + done fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then From eeeeef390b2d7f1aefe742ac069565d5f8eb8a38 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 12 Oct 2023 22:02:49 +0800 Subject: [PATCH 012/216] Minor bug fixes and descriptive text for the `LibriCSS` recipe (#1268) --- egs/libricss/SURT/prepare.sh | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/egs/libricss/SURT/prepare.sh b/egs/libricss/SURT/prepare.sh index 3d2581d96..b2d37f949 100755 --- a/egs/libricss/SURT/prepare.sh +++ b/egs/libricss/SURT/prepare.sh @@ -90,6 +90,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # NOTE: Alignments are required for this recipe. mkdir -p data/manifests + log "This recipe uses mfa alignment for trimming" + if [ ! -d $dl_dir/libri_alignments/LibriSpeech ]; then + log "No alignment provided. 
please refer to ../../librispeech/ASR/add_alignments.sh \n \ + for mfa alignments. Once you have downloaded and unzipped the .zip file containing \n \ + all alignments, the folder should be renamed to libri_alignments and moved to your $dl_dir ." + exit 0 + fi + lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \ -j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/ fi @@ -118,9 +126,12 @@ fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Extract features for LibriSpeech, trim to alignments, and shuffle the cuts" - python local/compute_fbank_librispeech.py - lhotse combine data/manifests/librispeech_cuts_train* - |\ - lhotse cut trim-to-alignments --type word --max-pause 0.2 - - |\ + # python local/compute_fbank_librispeech.py + lhotse combine data/manifests/librispeech_cuts_train* data/manifests/librispeech_cuts_train_all.jsonl.gz + lhotse cut trim-to-alignments --type word --max-pause 0.2 \ + data/manifests/librispeech_cuts_train_all.jsonl.gz \ + data/manifests/librispeech_cuts_train_all_trimmed.jsonl.gz + cat <(gunzip -c data/manifests/librispeech_cuts_train_all_trimmed.jsonl.gz) | \ shuf | gzip -c > data/manifests/librispeech_cuts_train_trimmed.jsonl.gz fi @@ -152,7 +163,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then data/manifests/lsmix_cuts_train_clean_ov40.jsonl.gz # Full training set (2,3 speakers) anechoic - log "Generating anechoic ${part} set (full)" + log "Generating anechoic set (full)" lhotse workflows simulate-meetings \ --method conversational \ --fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz \ From 1ef349d120acef5d48feee58c4462a56f4a8c995 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 16 Oct 2023 16:28:16 +0800 Subject: [PATCH 013/216] [WIP] AISHELL-1 pruned transducer stateless7 streaming recipe (#1300) * `pruned_transudcer_stateless7_streaming` for AISHELL-1 * Update train.py * Update train2.py * Update decode.py * Update RESULTS.md --- egs/aishell/ASR/RESULTS.md | 50 + .../README.md | 1 + .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 735 ++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export-for-ncnn-zh.py | 1 + .../export-for-ncnn.py | 1 + .../export-onnx-zh.py | 1 + .../export-onnx.py | 1 + .../export.py | 1 + .../jit_pretrained.py | 1 + .../jit_trace_export.py | 1 + .../jit_trace_pretrained.py | 1 + .../joiner.py | 1 + .../model.py | 1 + .../ncnn_custom_layer.py | 1 + .../onnx_check.py | 1 + .../onnx_model_wrapper.py | 1 + .../onnx_pretrained.py | 1 + .../optim.py | 1 + .../pretrained.py | 1 + .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming-ncnn-decode.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 627 +++++++++ .../test_model.py | 1 + .../train.py | 1251 ++++++++++++++++ .../train2.py | 1253 +++++++++++++++++ .../zipformer.py | 1 + .../zipformer2.py | 1 + 34 files changed, 3945 insertions(+) create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py create 
mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 5088497a1..a2d32013a 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -2,6 +2,56 @@ ### Aishell training result(Stateless Transducer) +#### Pruned transducer stateless 7 streaming +[./pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +It's Streaming version of Zipformer1 with Pruned RNNT loss. 
+ +| | test | dev | comment | +|------------------------|------|------|---------------------------------------| +| greedy search | 6.95 | 6.29 | --epoch 44 --avg 15 --max-duration 600 | +| modified beam search | 6.51 | 5.90 | --epoch 44 --avg 15 --max-duration 600 | +| fast beam search | 6.73 | 6.09 | --epoch 44 --avg 15 --max-duration 600 | + +Training command is: + +```bash +./prepare.sh + +export CUDA_VISIBLE_DEVICES="0,1" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --use-fp16 1 \ + --context-size 1 \ + --max-duration 800 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --enable-musan 0 \ + --spec-aug-time-warp-factor 20 +``` + +**Caution**: It uses `--context-size=1`. + +The decoding command is: +```bash +for m in greedy_search modified_beam_search fast_beam_search ; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 44 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang-dir data/lang_char \ + --context-size 1 \ + --decoding-method $m +done +``` + +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at + + + + #### Pruned transducer stateless 7 [./pruned_transducer_stateless7](./pruned_transducer_stateless7) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md new file mode 120000 index 000000000..a784292cd --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/README.md \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..f5ae836fd --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,735 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
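+
+# Note: this script decodes AISHELL-1 with a character lexicon
+# (--lang-dir, default data/lang_char) rather than a BPE model, which is
+# why it builds a Lexicon and a CharCtcTrainingGraphCompiler further below.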
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import ContextGraph +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding_method is modified_beam_search. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding_method is modified_beam_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + token_table: k2.SymbolTable, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. 
For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + token_table: + It maps token ID to a string. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += 30 + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, 30), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + else: + hyp_tokens = [] + batch_size = encoder_out.size(0) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_tokens.append(hyp) + + hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + key = f"beam_size_{params.beam_size}" + if params.has_contexts: + key += f"-context-score-{params.context_score}" + else: + key += "-no-context-words" + return {key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + token_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. 
+ params: + It is returned by :func:`get_params`. + model: + The neural model. + token_table: + It maps a token ID to a string. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + token_table=token_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
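+        # errs-*.txt holds the aligned ref/hyp pairs and per-word error
+        # counts produced by write_error_stats(); wer-summary-*.txt (written
+        # further below) collects one WER per decoding setting.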
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += "-no-contexts-words" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", 
model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + if params.decoding_method == "modified_beam_search": + if os.path.exists(params.context_file): + contexts_text = [] + for line in open(params.context_file).readlines(): + contexts_text.append(line.strip()) + contexts = graph_compiler.texts_to_ids(contexts_text) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
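+    # With return_cuts=True the dataloader puts the Cut objects into
+    # batch["supervisions"]["cut"], which decode_dataset() above uses to
+    # recover the cut ids written to the recogs-*.txt files.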
+ args.return_cuts = True + aishell = AishellAsrDataModule(args) + + test_cuts = aishell.test_cuts() + dev_cuts = aishell.valid_cuts() + + test_dl = aishell.test_dataloaders(test_cuts) + dev_dl = aishell.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] + + import time + + for test_set, test_dl in zip(test_sets, test_dls): + start = time.time() + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + token_table=lexicon.token_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + ) + logging.info(f"Elasped time for {test_set}: {time.time() - start}") + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py new file mode 120000 index 000000000..72e43c297 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 120000 index 000000000..3b36924ef --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py new file mode 120000 index 000000000..eca5e2956 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 120000 index 000000000..57a0cd0a0 --- /dev/null +++ 
b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 120000 index 000000000..2acafdc61 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py new file mode 120000 index 000000000..5d9c6ba00 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 120000 index 000000000..457131699 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 120000 index 000000000..2b8fa3cbb --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..e17d4f734 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py new file mode 120000 index 000000000..8eea90e04 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 120000 index 000000000..28bf7bb82 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git 
a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 120000 index 000000000..c8548d459 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 120000 index 000000000..ae4d9bb04 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 120000 index 000000000..9510b8fde --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 120000 index 000000000..92c3904af --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..1199a61d6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..6b4f183cf --- /dev/null +++ 
b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,627 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +import os +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall import ContextGraph +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + token_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
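The FbankOptions configured above are used just below to compute fbank features directly from each cut's waveform with kaldifeat instead of reading precomputed features. A minimal, self-contained sketch of that extraction step (fake audio, CPU only; the point is only the call pattern and the output shape):

    import torch
    from kaldifeat import Fbank, FbankOptions

    opts = FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
    # opts.device can be set to a CUDA device, as done above.

    fbank = Fbank(opts)
    samples = torch.rand(16000) * 2 - 1   # 1 s of fake 16 kHz audio in [-1, 1]
    features = fbank(samples)             # expected: 2-D tensor (num_frames, 80)
    print(features.shape)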
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + token_table[result] + for result in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + token_table[result] + for result in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
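One subtle point in the two decoding loops above: finished streams are deleted by index in descending order. A tiny standalone illustration of why the order matters:

    streams = ["s0", "s1", "s2", "s3"]
    finished = [1, 3]                       # indices reported as finished
    for i in sorted(finished, reverse=True):
        del streams[i]
    print(streams)                          # ['s0', 's2'], exactly the unfinished ones
    # Deleting in ascending order would shift the remaining indices after the
    # first deletion, so a later index would point past the end of the list
    # or at the wrong stream.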
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + if params.decoding_method == "modified_beam_search": + if os.path.exists(params.context_file): + contexts_text = [] + for line in open(params.context_file).readlines(): + contexts_text.append(line.strip()) + contexts = graph_compiler.texts_to_ids(contexts_text) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + aishell = AishellAsrDataModule(args) + + test_cuts = aishell.test_cuts() + valid_cuts = aishell.valid_cuts() + + test_sets = ["test", "valid"] + cuts = [test_cuts, valid_cuts] + + for test_set, test_cut in zip(test_sets, cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + token_table=lexicon.token_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 120000 index 000000000..1259849e0 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..2e1044658 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1251 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma 
separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. 
+ """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> 
Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
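The comment above describes a manual growth policy layered on top of GradScaler. Reduced to a standalone sketch (plain floats instead of a scaler; the thresholds are copied from the code below, and the helper name is illustrative):

    def bumped_scale(cur_scale: float, batch_idx: int) -> float:
        if cur_scale < 1.0e-05:
            raise RuntimeError(f"grad_scale is too small: {cur_scale}")
        # always recover scales below 1.0, and nudge scales below 8.0
        # every 400 batches
        if cur_scale < 1.0 or (cur_scale < 8.0 and batch_idx % 400 == 0):
            cur_scale *= 2.0   # stands in for scaler.update(cur_grad_scale * 2.0)
        return cur_scale

    # e.g. a scale that has decayed to 0.125 is back to 1.0 after three of the
    # 100-batch checks performed above:
    s = 0.125
    for step in (100, 200, 300):
        s = bumped_scale(s, step)
    print(s)  # 1.0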
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + oov="", + ) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + # T = ((c.num_frames - 7) // 2 + 1) // 2 + # tokens = sp.encode(c.supervisions[0].text, out_type=str) + + # if T < len(tokens): + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
" + # f"Number of frames (before subsampling): {c.num_frames}. " + # f"Number of frames (after subsampling): {T}. " + # f"Text: {c.supervisions[0].text}. " + # f"Tokens: {tokens}. " + # f"Number of tokens: {len(tokens)}" + # ) + return False + + return True + + # train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = aishell.valid_cuts() + valid_dl = aishell.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # graph_compiler=graph_compiler, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..88eb34104 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1253 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
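Both train.py above and train2.py below describe the `model_avg` update in the --average-period help as `model_avg = model * (average_period / batch_idx_train) + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. A quick numeric check with a scalar stand-in for the parameters, under the simplifying assumption that updates land exactly every average_period batches, shows this is an incremental mean over the periodic snapshots:

    period = 200
    snapshots = [1.0, 2.0, 3.0, 4.0]   # "model" value at batches 200, 400, 600, 800
    model_avg = snapshots[0]           # after the first update the average equals the model
    for n, model in enumerate(snapshots[1:], start=2):
        t = n * period                 # batch_idx_train at this update
        model_avg = model * (period / t) + model_avg * ((t - period) / t)
    print(model_avg, sum(snapshots) / len(snapshots))   # 2.5 2.5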
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor: The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+        num_left_chunks=params.num_left_chunks,
+        short_chunk_size=params.short_chunk_size,
+        decode_chunk_size=params.decode_chunk_len // 2,
+        is_pnnx=True,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = batch["supervisions"]["text"]
+    y = graph_compiler.texts_to_ids(texts)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler; we call step() for every batch.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+ scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + oov="", + ) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + # T = ((c.num_frames - 7) // 2 + 1) // 2 + # tokens = sp.encode(c.supervisions[0].text, out_type=str) + + # if T < len(tokens): + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
" + # f"Number of frames (before subsampling): {c.num_frames}. " + # f"Number of frames (after subsampling): {T}. " + # f"Text: {c.supervisions[0].text}. " + # f"Tokens: {tokens}. " + # f"Number of tokens: {len(tokens)}" + # ) + # return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = aishell.valid_cuts() + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 120000 index 000000000..12dbda888 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file From d2bd0933b1462fefdc7ac2b41881ae0eb71be873 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 17 Oct 2023 21:22:32 +0800 Subject: [PATCH 014/216] Compatibility with the latest Lhotse (#1314) --- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 +- .../ASR/transducer_stateless_modified-2/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 +- .../ASR_v2/pruned_transducer_stateless7/asr_datamodule.py | 3 +-- egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 +- egs/ami/SURT/dprnn_zipformer/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 +- egs/csj/ASR/local/utils/asr_datamodule.py | 2 +- egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 +- egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py | 2 +- egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless3/asr_datamodule.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py | 2 +- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- egs/mgb2/ASR/conformer_ctc/asr_datamodule.py | 3 +-- egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 +-- 
egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 +- egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py | 2 +- egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 +- 26 files changed, 26 insertions(+), 29 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 49a697bfd..3667c2ad0 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -211,7 +211,7 @@ class Aidatatang_200zhAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index 9c6021a19..cd8dd821c 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -160,7 +160,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py index af37cc175..8f6a88f59 100644 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -218,7 +218,7 @@ class AiShell2AsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index da9da371e..4ad98fb51 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -228,7 +228,7 @@ class Aishell4AsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index 4799da19d..5ad80817a 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -211,7 +211,7 @@ class AlimeetingAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: 
logging.info("Disable MUSAN") diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py index 1cfd053c7..9d288218a 100644 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py @@ -208,7 +208,7 @@ class AlimeetingAsrDataModule: logging.info("Enable MUSAN") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -288,7 +288,6 @@ class AlimeetingAsrDataModule: return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] if self.args.concatenate_cuts: transforms = [ diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py index f7ee9c962..79474f1d8 100644 --- a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -214,7 +214,7 @@ class AmiAsrDataModule: logging.info("Enable MUSAN") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py index 3dd786d33..1549c1631 100644 --- a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py @@ -202,7 +202,7 @@ class AmiAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py index 73f2f1dce..546e9f9dd 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -230,7 +230,7 @@ class CommonVoiceAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py index 272486227..042b6ecbf 100644 --- a/egs/csj/ASR/local/utils/asr_datamodule.py +++ b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -256,7 +256,7 @@ class CSJAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index 9d6e3c42a..a93e224d5 100644 --- 
a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -194,7 +194,7 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 29e72b408..b5b27ce95 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -217,7 +217,7 @@ class GigaSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py index a72df89e0..c1abdbdb5 100644 --- a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py @@ -204,7 +204,7 @@ class LibriCssAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index f8f558ce1..ee7556e49 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -209,7 +209,7 @@ class LibriSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index b7735be85..057624272 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -164,7 +164,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py index 75e153cb0..cd432fd6f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py @@ -217,7 +217,7 @@ class GigaSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + 
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 20df469da..c500eb3e5 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -233,7 +233,7 @@ class LibriSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py index 442ff85c2..7753d1674 100644 --- a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py @@ -182,7 +182,6 @@ class MGB2AsrDataModule: cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, ) -> DataLoader: - transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") @@ -190,7 +189,7 @@ class MGB2AsrDataModule: cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index 3d58ebf3a..02cfa1346 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -219,7 +219,7 @@ class AsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index d94a92503..cf70fc0f8 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -182,7 +182,7 @@ class SPGISpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -261,7 +261,6 @@ class SPGISpeechAsrDataModule: return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] if self.args.concatenate_cuts: transforms = [ diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index aeeb2ef78..ce8634a1d 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -220,7 +220,7 @@ class SwitchBoardAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git 
a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 39beffdcf..5269a1778 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -228,7 +228,7 @@ class TAL_CSASRAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index 28d0d3826..d4a9e4bc9 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -205,7 +205,7 @@ class TedLiumAsrDataModule: logging.info("Enable MUSAN") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py index 7c299d601..5d1b3c367 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -157,7 +157,7 @@ class TimitAsrDataModule(DataModule): cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz") logging.info("About to create train dataset") - transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + transforms = [CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20))] if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c5967f10a..1dbfb9709 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -215,7 +215,7 @@ class WenetSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py index 6362ab7cd..7594fb28e 100644 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -218,7 +218,7 @@ class Xbmu_AmdoAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") From 807816fec0dde1bfa0f0a2f20d36552cc3d84a90 Mon Sep 17 00:00:00 2001 From: Erwan Zerhouni <61225408+ezerhouni@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:07:10 +0200 Subject: [PATCH 015/216] Fix chunk issue for sherpa (#1316) --- egs/librispeech/ASR/zipformer/zipformer.py | 31 +++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) diff 
--git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 1a174b315..61ae378d8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -17,28 +17,33 @@ # limitations under the License. import copy +import logging import math +import random import warnings from typing import List, Optional, Tuple, Union -import logging + import torch -import random from encoder_interface import EncoderInterface from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, Balancer, BiasNorm, - Dropout2, ChunkCausalDepthwiseConv1d, - ActivationDropoutAndLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Dropout2, + FloatLike, + ScheduledFloat, Whiten, - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + convert_num_channels, + limit_param_value, penalize_abs_values_gt, softmax, - ScheduledFloat, - FloatLike, - limit_param_value, - convert_num_channels, ) from torch import Tensor, nn @@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module): (seq_len, batch_size, _) = x.shape hidden_channels = self.hidden_channels - s, x, y = x.chunk(3, dim=-1) + s, x, y = x.chunk(3, dim=2) # s will go through tanh. @@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module): (seq_len, batch_size, _) = x.shape hidden_channels = self.hidden_channels - s, x, y = x.chunk(3, dim=-1) + s, x, y = x.chunk(3, dim=2) # s will go through tanh. s = self.tanh(s) @@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module): x = self.in_proj(x) # (time, batch, 2*channels) - x, s = x.chunk(2, dim=-1) + x, s = x.chunk(2, dim=2) s = self.balancer1(s) s = self.sigmoid(s) x = self.activation1(x) # identity. 
From 52c24df61da3d04a6fdcab32d5615c394951279b Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 18 Oct 2023 17:36:14 +0800 Subject: [PATCH 016/216] Fix model avg (#1317) * fix a bug about the model_avg during finetuning by exchanging the order of loading pre-trained model and initializing avg model * only match the exact module prefix --- .../ASR/pruned_transducer_stateless7/finetune.py | 11 +++++++++-- .../ASR/pruned_transducer_stateless2/finetune.py | 8 ++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 4e261dbc1..a7a8ef149 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -655,8 +655,12 @@ def load_model_params( dst_state_dict = model.state_dict() for module in init_modules: logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] - dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] assert set(src_keys) == set(dst_keys) # two sets should match exactly for key in src_keys: dst_state_dict[key] = src_state_dict.pop(key) @@ -1089,6 +1093,9 @@ def run(rank, world_size, args): checkpoints = load_model_params( ckpt=params.finetune_ckpt, model=model, init_modules=modules ) + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) else: assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index c943a84af..ba91980d3 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -498,8 +498,12 @@ def load_model_params( dst_state_dict = model.state_dict() for module in init_modules: logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] - dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] assert set(src_keys) == set(dst_keys) # two sets should match exactly for key in src_keys: dst_state_dict[key] = src_state_dict.pop(key) From 98c5286404a0add86bc6243171fc092ea89c51bb Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Thu, 19 Oct 2023 01:13:50 +0900 Subject: [PATCH 017/216] Fix typo in code-style.rst (#1318) --- docs/source/contributing/code-style.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/contributing/code-style.rst b/docs/source/contributing/code-style.rst index 3baaaeec2..cb08229c3 100644 --- a/docs/source/contributing/code-style.rst +++ b/docs/source/contributing/code-style.rst @@ -38,7 +38,7 @@ Please fix any issues reported by the check tools. .. HINT:: Some of the check tools, i.e., ``black`` and ``isort`` will modify - the files to be commited **in-place**. 
So please run ``git status`` + the files to be committed **in-place**. So please run ``git status`` after failure to see which file has been modified by the tools before you make any further changes. From 36c60b0cf6b172e7739f5b177e731faa03737967 Mon Sep 17 00:00:00 2001 From: Surav Shrestha <98219089+suravshresth@users.noreply.github.com> Date: Thu, 19 Oct 2023 09:00:18 +0545 Subject: [PATCH 018/216] fix typos in icefall/utils.py (#1319) --- icefall/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index 410340d9d..6479d8f87 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1447,7 +1447,7 @@ def get_parameter_groups_with_lrs( This is for use with the ScaledAdam optimizers (more recent versions that accept lists of named-parameters; we can, if needed, create a version without the names). - It provides a way to specifiy learning-rate scales inside the module, so that if + It provides a way to specify learning-rate scales inside the module, so that if any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will scale the LR of any parameters inside that module or its submodules. Note: you can set module parameters outside the __init__ function, e.g.: @@ -1607,10 +1607,10 @@ def tokenize_by_bpe_model( chars = pattern.split(txt.upper()) mix_chars = [w for w in chars if len(w.strip()) > 0] for ch_or_w in mix_chars: - # ch_or_w is a single CJK charater(i.e., "你"), do nothing. + # ch_or_w is a single CJK character(i.e., "你"), do nothing. if pattern.fullmatch(ch_or_w) is not None: tokens.append(ch_or_w) - # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), + # ch_or_w contains non-CJK characters(i.e., " IT'S OKAY "), # encode ch_or_w using bpe_model. else: for p in sp.encode_as_pieces(ch_or_w): @@ -1624,7 +1624,7 @@ def tokenize_by_CJK_char(line: str) -> str: """ Tokenize a line of text with CJK char. - Note: All return charaters will be upper case. + Note: All return characters will be upper case. Example: input = "你好世界是 hello world 的中文" @@ -1917,7 +1917,7 @@ def parse_bpe_timestamps_and_texts( A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. containing multiple FSAs, which is expected to be the result of k2.shortest_path (otherwise the returned values won't - be meaningful). Its attribtutes `labels` and `aux_labels` + be meaningful). Its attributes `labels` and `aux_labels` are both BPE tokens. sp: The BPE model. @@ -2045,7 +2045,7 @@ def parse_fsa_timestamps_and_texts( ) -> Tuple[List[Tuple[float, float]], List[List[str]]]: """Parse timestamps (in seconds) and texts for given decoded fsa paths. Currently it supports two cases: - (1) ctc-decoding, the attribtutes `labels` and `aux_labels` + (1) ctc-decoding, the attributes `labels` and `aux_labels` are both BPE tokens. In this case, sp should be provided. (2) HLG-based 1best, the attribtute `labels` is the prediction unit, e.g., phone or BPE tokens; attribute `aux_labels` is the word index. 
From ce372cce33ad7594baf603f75264950d88fa329c Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 19 Oct 2023 17:24:31 +0800 Subject: [PATCH 019/216] Update documentation to PromptASR (#1321) --- .../zipformer_prompt_asr/train_baseline.py | 45 +++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index 7075c9154..32302602c 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -21,21 +22,35 @@ Usage: -# For mix precision training: +# For mix precision training, using MCP style transcript: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./zipformer/train.py \ +./zipformer_prompt_asr/train_baseline.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir zipformer/exp \ + --exp-dir zipformer_prompt_asr/exp \ + --transcript-style MCP \ + --max-duration 1000 + +# For mix precision training, using UC style transcript: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer_prompt_asr/train_baseline.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer_prompt_asr/exp \ + --transcript-style UC \ --max-duration 1000 # To train a streaming model -./zipformer/train.py \ +./zipformer_prompt_asr/train_baseline.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ @@ -100,7 +115,7 @@ from icefall.utils import ( LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -def get_first( +def get_mixed_cased_with_punc( texts: List[str], pre_texts: List[str], context_list: Optional[str] = None, @@ -479,6 +494,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--transcript-style", + type=str, + default="UC", + choices=["UC", "MCP"], + help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations, + MCP stands for mix-cased text with punctuation. 
+ """, + ) + add_model_arguments(parser) return parser @@ -1223,7 +1248,11 @@ def run(rank, world_size, args): else: sampler_state_dict = None - text_sampling_func = get_upper_only_alpha + if params.transcript_style == "UC": + text_sampling_func = get_upper_only_alpha + else: + text_sampling_func = get_mixed_cased_with_punc + logging.info(f"Using {params.transcript_style} style for training.") logging.info(f"Text sampling func: {text_sampling_func}") train_dl = libriheavy.train_dataloaders( train_cuts, From 543b4cc1ca45f5a6e273cb1440a233e5fc51fa36 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Thu, 19 Oct 2023 15:53:31 +0200 Subject: [PATCH 020/216] small enhanecements (#1322) - add extra check of 'x' and 'x_lens' to earlier point in Transducer model - specify 'utf' encoding when opening text files for writing (recogs, errs) --- egs/librispeech/ASR/pruned_transducer_stateless7/model.py | 3 +++ icefall/utils.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 0e59b0f2f..add0e6a18 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -114,6 +114,9 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 + # x.T_dim == max(x_len) + assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max()) + encoder_out, x_lens = self.encoder(x, x_lens) assert torch.all(x_lens > 0) diff --git a/icefall/utils.py b/icefall/utils.py index 6479d8f87..399e8d8b3 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -498,7 +498,7 @@ def store_transcripts( Returns: Return None. """ - with open(filename, "w") as f: + with open(filename, "w", encoding="utf8") as f: for cut_id, ref, hyp in texts: if char_level: ref = list("".join(ref)) @@ -523,7 +523,7 @@ def store_transcripts_and_timestamps( Returns: Return None. """ - with open(filename, "w") as f: + with open(filename, "w", encoding="utf8") as f: for cut_id, ref, hyp, time_ref, time_hyp in texts: print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) From 973dc1026d93c5ce551428459077187a3cd1e0a9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 19 Oct 2023 22:54:00 +0800 Subject: [PATCH 021/216] Make diagnostics.py more error-tolerant and have wider range of supported torch versions (#1234) --- icefall/diagnostics.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 700dc1500..ebf61784e 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -244,12 +244,22 @@ class TensorDiagnostic(object): if stats_type == "eigs": try: - eigs, _ = torch.symeig(stats) + if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'): + eigs, _ = torch.linalg.eigh(stats) + else: + eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print("Error getting eigenvalues, trying another method.") - eigs, _ = torch.eig(stats) - stats = eigs.norm(dim=1).sqrt() + print( + "Error getting eigenvalues, trying another method." 
+ ) + if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'): + eigs, _ = torch.linalg.eig(stats) + eigs = eigs.abs() + else: + eigs, _ = torch.eig(stats) + eigs = eigs.norm(dim=1) + stats = eigs.sqrt() # sqrt so it reflects data magnitude, like stddev- not variance if stats_type in ["rms", "stddev"]: @@ -569,11 +579,10 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in (torch.float32, torch.float16, torch.float64): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate( - o, class_name=get_class_name(_module) - ) - + if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, + class_name=get_class_name(_module)) + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] @@ -587,11 +596,9 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in (torch.float32, torch.float16, torch.float64): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( - o, class_name=get_class_name(_module) - ) - + if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, + class_name=get_class_name(_module)) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) From eef47adee9aa765f41cd63a8d57049b02849f3ad Mon Sep 17 00:00:00 2001 From: Rudra <92840555+Rudra-Ji@users.noreply.github.com> Date: Thu, 19 Oct 2023 20:24:43 +0530 Subject: [PATCH 022/216] fix typo (#1324) --- docs/source/decoding-with-langugage-models/LODR.rst | 2 +- docs/source/model-export/export-ncnn-conv-emformer.rst | 2 +- egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py | 2 +- egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py | 2 +- icefall/utils.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst index 8cc1a624c..b6b6e8cbb 100644 --- a/docs/source/decoding-with-langugage-models/LODR.rst +++ b/docs/source/decoding-with-langugage-models/LODR.rst @@ -56,7 +56,7 @@ during decoding for transducer model: \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right) -In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR, +In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR, the only difference lies in the choice of source domain LM. According to the original `paper `_, LODR achieves similar performance compared DR in both intra-domain and cross-domain settings. As a bi-gram is much faster to evaluate, LODR is usually much faster. diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst index 4f5535d83..93392aee7 100644 --- a/docs/source/model-export/export-ncnn-conv-emformer.rst +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -125,7 +125,7 @@ Python code. We have also set up ``PATH`` so that you can use .. caution:: Please don't use ``_. - We have made some modifications to the offical `ncnn`_. + We have made some modifications to the official `ncnn`_. 
We will synchronize ``_ periodically with the official one. diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index bdd1f27bc..2bafe25d6 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -203,7 +203,7 @@ def get_parser(): "--beam-size", type=int, default=4, - help="""An interger indicating how many candidates we will keep for each + help="""An integer indicating how many candidates we will keep for each frame. Used only when --decoding-method is beam_search or modified_beam_search.""", ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index ba91980d3..c34f1593d 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -78,7 +78,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): default=None, help=""" Modules to be initialized. It matches all parameters starting with - a specific key. The keys are given with Comma seperated. If None, + a specific key. The keys are given with Comma separated. If None, all modules will be initialised. For example, if you only want to initialise all parameters staring with "encoder", use "encoder"; if you want to initialise parameters starting with encoder or decoder, diff --git a/icefall/utils.py b/icefall/utils.py index 399e8d8b3..a9e8a81b9 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1977,7 +1977,7 @@ def parse_timestamps_and_texts( A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. containing multiple FSAs, which is expected to be the result of k2.shortest_path (otherwise the returned values won't - be meaningful). Attribtute `labels` is the prediction unit, + be meaningful). Attribute `labels` is the prediction unit, e.g., phone or BPE tokens. Attribute `aux_labels` is the word index. word_table: The word symbol table. 
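
The diagnostics patch above (PATCH 021) guards the eigendecomposition behind `hasattr` checks so that one code path works across a wide range of torch releases. As an editorial aside, here is a minimal, self-contained sketch of that fallback pattern, assuming a recent torch install; the helper name `robust_eig_magnitudes` is ours and not part of icefall, and the legacy branches (`torch.symeig`, `torch.eig`) are only reached on old torch versions that still ship them.

```python
#!/usr/bin/env python3
# Minimal sketch (not part of icefall) of the version-tolerant eigenvalue
# computation illustrated by the diagnostics.py change above.
import torch


def robust_eig_magnitudes(stats: torch.Tensor) -> torch.Tensor:
    """Return sqrt(|eigenvalues|) of `stats`, across a range of torch versions."""
    try:
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
            # Newer torch: torch.linalg.eigh -> (eigenvalues, eigenvectors)
            eigs, _ = torch.linalg.eigh(stats)
        else:
            # Older torch: torch.symeig (removed in recent releases)
            eigs, _ = torch.symeig(stats)
        return eigs.abs().sqrt()
    except Exception:  # noqa
        # Non-symmetric or numerically awkward input: use the general solver.
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
            eigs = torch.linalg.eig(stats).eigenvalues.abs()
        else:
            eigs, _ = torch.eig(stats)
            eigs = eigs.norm(dim=1)
        return eigs.sqrt()


if __name__ == "__main__":
    m = torch.randn(4, 4)
    # Symmetric positive semi-definite input takes the eigh/symeig branch.
    print(robust_eig_magnitudes(m @ m.t()))
```

Feature-detecting with `hasattr` rather than parsing `torch.__version__` keeps the check cheap and avoids assumptions about version string formats, which is presumably why the patch takes that route.
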
From 416852e8a16f7f7f3104e95271c6d109088a416d Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sat, 21 Oct 2023 02:36:59 -0500 Subject: [PATCH 023/216] Add Zipformer recipe for GigaSpeech (#1254) Co-authored-by: Yifan Yang Co-authored-by: yfy62 --- .../run-gigaspeech-zipformer-2023-10-17.sh | 94 ++ .../run-gigaspeech-zipformer-2023-10-17.yml | 126 ++ README.md | 16 +- egs/gigaspeech/ASR/README.md | 1 + egs/gigaspeech/ASR/RESULTS.md | 74 + .../ASR/zipformer/asr_datamodule.py | 444 ++++++ egs/gigaspeech/ASR/zipformer/beam_search.py | 1 + egs/gigaspeech/ASR/zipformer/ctc_decode.py | 847 +++++++++++ egs/gigaspeech/ASR/zipformer/decode.py | 1065 +++++++++++++ egs/gigaspeech/ASR/zipformer/decode_stream.py | 1 + egs/gigaspeech/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-ctc.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/gigaspeech/ASR/zipformer/export-onnx.py | 620 ++++++++ egs/gigaspeech/ASR/zipformer/export.py | 522 +++++++ .../ASR/zipformer/gigaspeech_scoring.py | 1 + .../ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_ctc.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/gigaspeech/ASR/zipformer/joiner.py | 1 + egs/gigaspeech/ASR/zipformer/model.py | 1 + egs/gigaspeech/ASR/zipformer/onnx_check.py | 1 + egs/gigaspeech/ASR/zipformer/onnx_decode.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_H.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HLG.py | 1 + egs/gigaspeech/ASR/zipformer/optim.py | 1 + egs/gigaspeech/ASR/zipformer/pretrained.py | 1 + .../ASR/zipformer/pretrained_ctc.py | 1 + egs/gigaspeech/ASR/zipformer/profile.py | 1 + egs/gigaspeech/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 853 +++++++++++ egs/gigaspeech/ASR/zipformer/subsampling.py | 1 + egs/gigaspeech/ASR/zipformer/test_scaling.py | 1 + .../ASR/zipformer/test_subsampling.py | 1 + egs/gigaspeech/ASR/zipformer/train.py | 1345 +++++++++++++++++ egs/gigaspeech/ASR/zipformer/zipformer.py | 1 + 43 files changed, 6036 insertions(+), 2 deletions(-) create mode 100755 .github/scripts/run-gigaspeech-zipformer-2023-10-17.sh create mode 100644 .github/workflows/run-gigaspeech-zipformer-2023-10-17.yml create mode 100644 egs/gigaspeech/ASR/zipformer/asr_datamodule.py create mode 120000 egs/gigaspeech/ASR/zipformer/beam_search.py create mode 100755 egs/gigaspeech/ASR/zipformer/ctc_decode.py create mode 100755 egs/gigaspeech/ASR/zipformer/decode.py create mode 120000 egs/gigaspeech/ASR/zipformer/decode_stream.py create mode 120000 egs/gigaspeech/ASR/zipformer/decoder.py create mode 120000 egs/gigaspeech/ASR/zipformer/encoder_interface.py create mode 120000 egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py create mode 120000 egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py create mode 100755 egs/gigaspeech/ASR/zipformer/export-onnx.py create mode 100755 egs/gigaspeech/ASR/zipformer/export.py create mode 120000 egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py create mode 120000 egs/gigaspeech/ASR/zipformer/jit_pretrained.py create mode 120000 egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py create mode 120000 egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py create mode 
120000 egs/gigaspeech/ASR/zipformer/joiner.py create mode 120000 egs/gigaspeech/ASR/zipformer/model.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_check.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_decode.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py create mode 120000 egs/gigaspeech/ASR/zipformer/optim.py create mode 120000 egs/gigaspeech/ASR/zipformer/pretrained.py create mode 120000 egs/gigaspeech/ASR/zipformer/pretrained_ctc.py create mode 120000 egs/gigaspeech/ASR/zipformer/profile.py create mode 120000 egs/gigaspeech/ASR/zipformer/scaling.py create mode 120000 egs/gigaspeech/ASR/zipformer/scaling_converter.py create mode 120000 egs/gigaspeech/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/gigaspeech/ASR/zipformer/streaming_decode.py create mode 120000 egs/gigaspeech/ASR/zipformer/subsampling.py create mode 120000 egs/gigaspeech/ASR/zipformer/test_scaling.py create mode 120000 egs/gigaspeech/ASR/zipformer/test_subsampling.py create mode 100755 egs/gigaspeech/ASR/zipformer/train.py create mode 120000 egs/gigaspeech/ASR/zipformer/zipformer.py diff --git a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh new file mode 100755 index 000000000..6bb0b9ebc --- /dev/null +++ b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/gigaspeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "exp/jit_script.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./zipformer/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./zipformer/jit_pretrained.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --nn-model-filename $repo/exp/jit_script.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for method in greedy_search modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: 
${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p zipformer/exp + ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh zipformer/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./zipformer/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir zipformer/exp + done + + rm zipformer/exp/*.pt +fi diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml new file mode 100644 index 000000000..7572f4b5f --- /dev/null +++ b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml @@ -0,0 +1,126 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-gigaspeech-zipformer-2023-10-17 +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_gigaspeech_2023_10_17_zipformer-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_gigaspeech_2023_10_17_zipformer: + if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/gigaspeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/gigaspeech/ASR/data/fbank + ls -lh egs/gigaspeech/ASR/data/* + + sudo apt-get 
-qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-gigaspeech-zipformer-2023-10-17.sh + + - name: Display decoding results for gigaspeech zipformer + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/gigaspeech/ASR/ + tree ./zipformer/exp + + cd zipformer + echo "results for zipformer" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for gigaspeech zipformer + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 + path: egs/gigaspeech/ASR/zipformer/exp/ diff --git a/README.md b/README.md index da446109d..a14abd023 100644 --- a/README.md +++ b/README.md @@ -148,8 +148,11 @@ in the decoding. ### GigaSpeech -We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc] -and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2]. +We provide three models for this recipe: + +- [Conformer CTC model][GigaSpeech_conformer_ctc] +- [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2]. 
+- [Transducer: Zipformer encoder + Embedding decoder][GigaSpeech_zipformer] #### Conformer CTC @@ -165,6 +168,14 @@ and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned R | fast beam search | 10.50 | 10.69 | | modified beam search | 10.40 | 10.51 | +#### Transducer: Zipformer encoder + Embedding decoder + +| | Dev | Test | +|----------------------|-------|-------| +| greedy search | 10.31 | 10.50 | +| fast beam search | 10.26 | 10.48 | +| modified beam search | 10.25 | 10.38 | + ### Aishell @@ -378,6 +389,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless [GigaSpeech_conformer_ctc]: egs/gigaspeech/ASR/conformer_ctc [GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2 +[GigaSpeech_zipformer]: egs/gigaspeech/ASR/zipformer [Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2 [WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2 [WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5 diff --git a/egs/gigaspeech/ASR/README.md b/egs/gigaspeech/ASR/README.md index 32a0457c6..f0d60898c 100644 --- a/egs/gigaspeech/ASR/README.md +++ b/egs/gigaspeech/ASR/README.md @@ -15,6 +15,7 @@ ln -sfv /path/to/GigaSpeech download/GigaSpeech ## Performance Record | | Dev | Test | |--------------------------------|-------|-------| +| `zipformer` | 10.25 | 10.38 | | `conformer_ctc` | 10.47 | 10.58 | | `pruned_transducer_stateless2` | 10.40 | 10.51 | diff --git a/egs/gigaspeech/ASR/RESULTS.md b/egs/gigaspeech/ASR/RESULTS.md index 7ab565844..841ebdcfa 100644 --- a/egs/gigaspeech/ASR/RESULTS.md +++ b/egs/gigaspeech/ASR/RESULTS.md @@ -1,4 +1,78 @@ ## Results +### zipformer (zipformer + pruned stateless transducer) + +See for more details. + +[zipformer](./zipformer) + +- Non-streaming +- normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +The tensorboard log for training is available at + + +You can use to deploy it. 
+ +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 10.31 | 10.50 | --epoch 30 --avg 9 | +| modified_beam_search | 10.25 | 10.38 | --epoch 30 --avg 9 | +| fast_beam_search | 10.26 | 10.48 | --epoch 30 --avg 9 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 0 \ + --subset XL \ + --max-duration 700 \ + --use-transducer 1 \ + --use-ctc 0 \ + --lr-epochs 1 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES=0 + +# greedy search +./zipformer/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method greedy_search + +# modified beam search +./zipformer/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +# fast beam search (one best) +./zipformer/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +``` + ### GigaSpeech BPE training results (Pruned Transducer 2) #### 2022-05-12 diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..c4472ed23 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,444 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import glob +import inspect +import logging +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import lhotse +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class GigaSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). 
+ + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. 
" + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # GigaSpeech specific arguments + group.add_argument( + "--subset", + type=str, + default="XL", + help="Select the GigaSpeech subset (XS|S|M|L|XL)", + ) + group.add_argument( + "--small-dev", + type=str2bool, + default=False, + help="Should we use only 1000 utterances for dev (speeds up training)", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get train {self.args.subset} cuts") + if self.args.subset == "XL": + filenames = glob.glob( + f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" + ) + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + sorted_filenames = [f[1] for f in idx_filenames] + logging.info( + f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" + ) + cuts_train = lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) + else: + path = ( + self.args.manifest_dir / 
f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" + ) + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + ) + if self.args.small_dev: + return cuts_valid.subset(first=1000) + else: + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + ) diff --git a/egs/gigaspeech/ASR/zipformer/beam_search.py b/egs/gigaspeech/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..aa51036d5 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -0,0 +1,847 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(2) 1best +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(3) nbest +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(4) nbest-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(5) whole-lattice-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from train import add_model_arguments, get_params, get_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.6, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. 
`len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. 
+ # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + test_clean_cuts = gigaspeech.test_clean_cuts() + test_other_cuts = gigaspeech.test_other_cuts() + + test_clean_dl = gigaspeech.test_dataloaders(test_clean_cuts) + test_other_dl = gigaspeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/zipformer/decode.py b/egs/gigaspeech/ASR/zipformer/decode.py new file mode 100755 index 000000000..3a0c71484 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/decode.py @@ -0,0 +1,1065 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as 
spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. 
+ """, + ) + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." 
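+        # --chunk-size and --left-context-frames may be comma-separated lists
+        # during training (e.g. "16,32,64,-1"); decoding uses a single fixed
+        # value for each, hence the asserts above and below.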
+ assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} 
(excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/zipformer/decode_stream.py b/egs/gigaspeech/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/decoder.py b/egs/gigaspeech/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/encoder_interface.py b/egs/gigaspeech/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py b/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 000000000..f9d756352 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py b/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx.py b/egs/gigaspeech/ASR/zipformer/export-onnx.py new file mode 100755 index 000000000..0f78cfe5b --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export-onnx.py @@ -0,0 +1,620 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/gigaspeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
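+        Both the keys and the values are stored as strings in the model's
+        metadata_props.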
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range 
from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/gigaspeech/ASR/zipformer/export.py b/egs/gigaspeech/ASR/zipformer/export.py new file mode 100755 index 000000000..e45c96b57 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export.py @@ -0,0 +1,522 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple 
authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for gigaspeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. 
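+
+For example (a minimal sketch; adjust the checkpoint path to your setup):
+
+    import torch
+    from icefall.checkpoint import load_checkpoint
+
+    # `model` is built with get_model() from ./train.py
+    load_checkpoint("zipformer/exp/pretrained.pt", model)
+
+    # pretrained.pt stores the weights under a "model" key, so the above is
+    # equivalent to:
+    # model.load_state_dict(
+    #     torch.load("zipformer/exp/pretrained.pt", map_location="cpu")["model"]
+    # )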
+ +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/gigaspeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/gigaspeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +- non-streaming model: +https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. 
+ """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
+ """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py b/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py new file mode 120000 index 000000000..a6a4d12b1 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py @@ -0,0 +1 @@ +../conformer_ctc/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 000000000..9a8da5844 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/joiner.py b/egs/gigaspeech/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/model.py b/egs/gigaspeech/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_check.py b/egs/gigaspeech/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_decode.py b/egs/gigaspeech/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_decode.py 
@@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py new file mode 120000 index 000000000..a3183ebf6 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py new file mode 120000 index 000000000..a4fd76ac2 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py new file mode 120000 index 000000000..f805e3761 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py new file mode 120000 index 000000000..8343d5079 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/optim.py b/egs/gigaspeech/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/pretrained.py b/egs/gigaspeech/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py new file mode 120000 index 000000000..c2f6f6fc3 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/profile.py b/egs/gigaspeech/ASR/zipformer/profile.py new file mode 120000 index 000000000..c93adbd14 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/scaling.py b/egs/gigaspeech/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ 
b/egs/gigaspeech/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/scaling_converter.py b/egs/gigaspeech/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py b/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..a76788859 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py @@ -0,0 +1,853 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import GigaSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. 
+ + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
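+        # The stream holds this utterance's features, cached encoder states and
+        # partial decoding result while its chunks are decoded in batches below.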
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
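+        # write_error_stats() also returns the WER for this decoding setting,
+        # which is collected in test_set_wers for the summary file written below.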
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    GigaSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    assert params.causal, params.causal
+    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+    assert (
+        "," not in params.left_context_frames
+    ), "left_context_frames should be one value in decoding."
+    params.suffix += f"-chunk-{params.chunk_size}"
+    params.suffix += f"-left-context-{params.left_context_frames}"
+
+    # for fast_beam_search
+    if params.decoding_method == "fast_beam_search":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames =
find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + test_sets = ["dev", "test"] + test_cuts = [dev_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/zipformer/subsampling.py b/egs/gigaspeech/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/test_scaling.py b/egs/gigaspeech/ASR/zipformer/test_scaling.py new file mode 120000 index 000000000..715798436 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/test_scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/test_subsampling.py b/egs/gigaspeech/ASR/zipformer/test_subsampling.py new file mode 120000 index 000000000..bf0ee3d11 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/test_subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py new file mode 100755 index 000000000..d8ff4fecc --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -0,0 +1,1345 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
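+    # For example, with the usage shown above (--max-duration 1000,
+    # --world-size 8) and the default --ref-duration 600, each real batch
+    # counts as 1000 * 8 / 600, i.e. about 13.3 reference batches.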
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=1, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 500, + "reset_interval": 2000, + "valid_interval": 20000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + 
use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. 
It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
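As a side note, the scale adjustment performed in the next few lines can be read as the standalone sketch below. It assumes only the public `torch.cuda.amp.GradScaler` constructor arguments plus the private `_scale` attribute that the recipe itself reads; it is an illustration of the idea, not code used by the recipe.

```python
# Standalone sketch (not part of the recipe) contrasting the two options the
# comment above mentions: the scaler's built-in, scale-independent growth
# schedule versus the manual, scale-dependent doubling done in the loop.
from torch.cuda.amp import GradScaler


def make_stock_scaler() -> GradScaler:
    # Built-in schedule: multiply the scale by `growth_factor` after every
    # `growth_interval` consecutive overflow-free steps.  The interval cannot
    # be made to depend on how small the current scale is.
    return GradScaler(init_scale=1.0, growth_factor=2.0, growth_interval=2000)


def grow_if_scale_is_small(scaler: GradScaler, batch_idx: int) -> None:
    # Scale-dependent growth, applied by hand after scaler.scale(loss).backward()
    # and scaler.step(optimizer) have run at least once (so `_scale` exists):
    # double aggressively while the scale is tiny, more slowly while below 32.
    cur_scale = scaler._scale.item()
    if cur_scale < 8.0 or (cur_scale < 32.0 and batch_idx % 400 == 0):
        scaler.update(cur_scale * 2.0)
```

Doubling by hand lets the loss scale recover quickly after it has collapsed, without changing the scaler's normal growth behaviour once the scale is healthy again.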
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + gigaspeech = GigaSpeechAsrDataModule(args) + + train_cuts = gigaspeech.train_cuts() + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.dev_cuts() + valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + 
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/zipformer/zipformer.py b/egs/gigaspeech/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 902dc2364a693ce7c6b939a0c9cf64382f767147 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 22 Oct 2023 23:25:06 +0800 Subject: [PATCH 024/216] Update docker for torch 2.1 (#1326) --- .github/workflows/build-docker-image.yml | 9 ++- .github/workflows/run-docker-image.yml | 15 ++++- .github/workflows/run-yesno-recipe.yml | 4 +- docker/torch1.12.1-cuda11.3.dockerfile | 5 +- docker/torch1.13.0-cuda11.6.dockerfile | 5 +- docker/torch1.9.0-cuda10.2.dockerfile | 3 +- docker/torch2.0.0-cuda11.7.dockerfile | 5 +- docker/torch2.1.0-cuda11.8.dockerfile | 71 ++++++++++++++++++++++++ docker/torch2.1.0-cuda12.1.dockerfile | 71 ++++++++++++++++++++++++ docs/source/docker/intro.rst | 2 + 10 files changed, 179 insertions(+), 11 deletions(-) create mode 100644 docker/torch2.1.0-cuda11.8.dockerfile create mode 100644 docker/torch2.1.0-cuda12.1.dockerfile diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index 327f0ee45..e5d96dcdf 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout @@ -30,6 +30,13 @@ jobs: image=${{ matrix.image }} mv -v ./docker/$image.dockerfile ./Dockerfile + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + - name: Log in to Docker Hub uses: docker/login-action@v2 with: diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index 12604a132..d048923b6 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v2 @@ -30,8 +30,15 @@ jobs: uname -a cat /etc/*release + find / -name libcuda* 2>/dev/null + + ls -lh /usr/local/ + ls -lh /usr/local/cuda* + nvcc --version + ls -lh /usr/local/cuda-*/compat/* + # For torch1.9.0-cuda10.2 export LD_LIBRARY_PATH=/usr/local/cuda-10.2/compat:$LD_LIBRARY_PATH @@ 
-41,6 +48,12 @@ jobs: # For torch2.0.0-cuda11.7 export LD_LIBRARY_PATH=/usr/local/cuda-11.7/compat:$LD_LIBRARY_PATH + # For torch2.1.0-cuda11.8 + export LD_LIBRARY_PATH=/usr/local/cuda-11.8/compat:$LD_LIBRARY_PATH + + # For torch2.1.0-cuda12.1 + export LD_LIBRARY_PATH=/usr/local/cuda-12.1/compat:$LD_LIBRARY_PATH + which nvcc cuda_dir=$(dirname $(which nvcc)) diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 7d55a50e1..9ac848535 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -64,8 +64,8 @@ jobs: pip uninstall -y protobuf pip install --no-binary protobuf protobuf==3.20.* - pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl - pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html + pip install --no-deps --force-reinstall k2==1.24.4.dev20231021+cpu.torch1.13.1 -f https://k2-fsa.github.io/k2/cpu.html + pip install kaldifeat==1.25.1.dev20231022+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html - name: Run yesno recipe shell: bash diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index 5338bdca7..ed746abe3 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive -ARG K2_VERSION="1.24.3.dev20230725+cuda11.3.torch1.12.1" -ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.3.torch1.12.1" +# python 3.7 +ARG K2_VERSION="1.24.4.dev20230725+cuda11.3.torch1.12.1" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.3.torch1.12.1" ARG TORCHAUDIO_VERSION="0.12.1+cu113" LABEL authors="Fangjun Kuang " diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index 4d2f96c8e..9657866e5 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive -ARG K2_VERSION="1.24.3.dev20230725+cuda11.6.torch1.13.0" -ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.6.torch1.13.0" +# python 3.9 +ARG K2_VERSION="1.24.4.dev20231021+cuda11.6.torch1.13.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.6.torch1.13.0" ARG TORCHAUDIO_VERSION="0.13.0+cu116" LABEL authors="Fangjun Kuang " diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile index a7cef6dc8..a92af7ad0 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive +# python 3.7 ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0" -ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda10.2.torch1.9.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda10.2.torch1.9.0" ARG TORCHAUDIO_VERSION="0.9.0" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index d91fbc24f..07296e6f0 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive -ARG K2_VERSION="1.24.3.dev20230718+cuda11.7.torch2.0.0" -ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.7.torch2.0.0" +# python 3.10 +ARG K2_VERSION="1.24.4.dev20231021+cuda11.7.torch2.0.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.7.torch2.0.0" ARG 
TORCHAUDIO_VERSION="2.0.0+cu117" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile new file mode 100644 index 000000000..e500e9a6a --- /dev/null +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -0,0 +1,71 @@ +FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20231021+cuda11.8.torch2.1.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.8.torch2.1.0" +ARG TORCHAUDIO_VERSION="2.1.0+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile new file mode 100644 index 000000000..c3f12323e --- /dev/null +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -0,0 +1,71 @@ +FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20231021+cuda12.1.torch2.1.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda12.1.torch2.1.0" +ARG TORCHAUDIO_VERSION="2.1.0+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && 
\ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index b09247d85..9ead0df00 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -30,6 +30,8 @@ which will give you something like below: .. code-block:: bash + "torch2.1.0-cuda12.1" + "torch2.1.0-cuda11.8" "torch2.0.0-cuda11.7" "torch1.12.1-cuda11.3" "torch1.9.0-cuda10.2" From 92ef561ff71e531f243ff432561851bb4b93390a Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 24 Oct 2023 01:10:50 +0800 Subject: [PATCH 025/216] Minor fixes for torch.jit.script support (#1329) --- egs/aishell/ASR/transducer_stateless/decoder.py | 4 ++++ egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py | 4 ++++ egs/librispeech/ASR/pruned_transducer_stateless/decoder.py | 4 ++++ egs/librispeech/ASR/transducer_stateless/decoder.py | 4 ++++ egs/librispeech/ASR/zipformer/decoder.py | 5 ++++- 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index 70e9e6c96..130f080ec 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -70,6 +70,10 @@ class Decoder(nn.Module): groups=embedding_dim, bias=False, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py index 93e0f9f7e..8a55eb5c8 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py @@ -95,6 +95,10 @@ class Decoder(nn.Module): max_abs=1.0, prob=0.05, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 49b82c433..03847b449 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -74,6 +74,10 @@ class Decoder(nn.Module): groups=embedding_dim, bias=False, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() self.output_linear = nn.Linear(embedding_dim, vocab_size) def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index a182d91e2..ac6292f63 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -71,6 +71,10 @@ class Decoder(nn.Module): groups=embedding_dim, bias=False, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git 
a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index e77e54118..492d63fc5 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from scaling import Balancer @@ -95,6 +94,10 @@ class Decoder(nn.Module): max_abs=1.0, prob=0.05, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ From f9980aa606d2ea9bf3d73d65309fa161b2bc4765 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 24 Oct 2023 08:17:17 +0800 Subject: [PATCH 026/216] minor fixes (#1332) --- egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py | 1 + egs/librispeech/ASR/zipformer/decoder.py | 1 + 2 files changed, 2 insertions(+) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py index 8a55eb5c8..91f167204 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py @@ -99,6 +99,7 @@ class Decoder(nn.Module): # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` # when inference with torch.jit.script and context_size == 1 self.conv = nn.Identity() + self.balancer2 = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index 492d63fc5..7ce44495b 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -98,6 +98,7 @@ class Decoder(nn.Module): # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` # when inference with torch.jit.script and context_size == 1 self.conv = nn.Identity() + self.balancer2 = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ From 4b791ced78aa5b6d2ccc2d78458a3ed984b26e7f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 24 Oct 2023 10:38:56 +0800 Subject: [PATCH 027/216] Fix CI tests (#1333) --- requirements-ci.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 1eba69764..e1232a768 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -12,7 +12,7 @@ graphviz==0.19.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.13.1+cpu six --f https://k2-fsa.org/nightly/ k2==1.23.4.dev20230319+cpu.torch1.13.1 +-f https://k2-fsa.github.io/k2/cpu.html k2==1.24.4.dev20231022+cpu.torch1.13.1 git+https://github.com/lhotse-speech/lhotse kaldilm==1.11 From 3fb99400cf2c691f5c666fecd1415340820364a6 Mon Sep 17 00:00:00 2001 From: hairyputtar <148847552+hairyputtar@users.noreply.github.com> Date: Tue, 24 Oct 2023 13:17:25 +0530 Subject: [PATCH 028/216] fix typos (#1336) * fix typo * fix typo * Update pruned_transducer_stateless.rst --- docs/source/contributing/how-to-create-a-recipe.rst | 2 +- docs/source/recipes/Streaming-ASR/introduction.rst | 2 +- .../librispeech/pruned_transducer_stateless.rst | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/contributing/how-to-create-a-recipe.rst b/docs/source/contributing/how-to-create-a-recipe.rst index a30fb9056..168a856c3 100644 --- a/docs/source/contributing/how-to-create-a-recipe.rst +++ 
b/docs/source/contributing/how-to-create-a-recipe.rst @@ -3,7 +3,7 @@ How to create a recipe .. HINT:: - Please read :ref:`follow the code style` to adjust your code sytle. + Please read :ref:`follow the code style` to adjust your code style. .. CAUTION:: diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst index ac77a51d1..28f5b8fbf 100644 --- a/docs/source/recipes/Streaming-ASR/introduction.rst +++ b/docs/source/recipes/Streaming-ASR/introduction.rst @@ -32,7 +32,7 @@ In icefall, we implement the streaming conformer the way just like what `WeNet < .. HINT:: If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer to `this pull request `_. After adding the code needed by streaming training, - you have to re-train it with the extra arguments metioned in the docs above to get a streaming model. + you have to re-train it with the extra arguments mentioned in the docs above to get a streaming model. Streaming Emformer diff --git a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst index 2ca70bcf3..d6e424e2f 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -584,7 +584,7 @@ The following shows two examples (for the two types of checkpoints): - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and `espnet/nets/beam_search_transducer.py `_ - is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to + is used as a reference. Basically, it keeps topk states for each frame, and expands the kept states with their own contexts to next frame. - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it @@ -648,7 +648,7 @@ command to extract ``model.state_dict()``. .. caution:: ``--streaming-model`` and ``--causal-convolution`` require to be True to export - a streaming mdoel. + a streaming model. It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``. @@ -697,7 +697,7 @@ Export model using ``torch.jit.script()`` .. caution:: ``--streaming-model`` and ``--causal-convolution`` require to be True to export - a streaming mdoel. + a streaming model. It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later load it by ``torch.jit.load("cpu_jit.pt")``. 
From d76c3fe4726ccf7f1f53f5e0f0607aa3dfec12c0 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 24 Oct 2023 16:24:46 +0800 Subject: [PATCH 029/216] Migrate zipformer model to other Chinese datasets (#1216) added zipformer recipe for AISHELL-1 --- ...pruned-transducer-stateless3-2022-06-20.sh | 2 +- .../run-aishell-zipformer-2023-10-24.sh | 103 ++ .../run-aishell-zipformer-2023-10-24.yml | 95 ++ egs/aidatatang_200zh/ASR/prepare.sh | 4 +- .../asr_datamodule.py | 3 +- egs/aishell/ASR/README.md | 6 +- egs/aishell/ASR/RESULTS.md | 158 +- egs/aishell/ASR/prepare.sh | 3 +- egs/aishell/ASR/zipformer/__init__.py | 0 egs/aishell/ASR/zipformer/asr_datamodule.py | 1 + egs/aishell/ASR/zipformer/beam_search.py | 1 + egs/aishell/ASR/zipformer/decode.py | 814 ++++++++++ egs/aishell/ASR/zipformer/decode_stream.py | 1 + egs/aishell/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/aishell/ASR/zipformer/export-onnx.py | 1 + egs/aishell/ASR/zipformer/export.py | 1 + egs/aishell/ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/aishell/ASR/zipformer/joiner.py | 1 + egs/aishell/ASR/zipformer/model.py | 1 + egs/aishell/ASR/zipformer/onnx_check.py | 1 + egs/aishell/ASR/zipformer/onnx_decode.py | 286 ++++ .../zipformer/onnx_pretrained-streaming.py | 1 + egs/aishell/ASR/zipformer/onnx_pretrained.py | 1 + egs/aishell/ASR/zipformer/optim.py | 1 + egs/aishell/ASR/zipformer/pretrained.py | 1 + egs/aishell/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + egs/aishell/ASR/zipformer/streaming_decode.py | 880 +++++++++++ egs/aishell/ASR/zipformer/subsampling.py | 1 + egs/aishell/ASR/zipformer/train.py | 1350 +++++++++++++++++ egs/aishell/ASR/zipformer/zipformer.py | 1 + egs/aishell2/ASR/README.md | 6 +- egs/aishell2/ASR/RESULTS.md | 8 +- egs/aishell2/ASR/prepare.sh | 6 +- egs/aishell4/ASR/README.md | 6 +- egs/aishell4/ASR/prepare.sh | 4 +- .../asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/local | 1 + 42 files changed, 3741 insertions(+), 18 deletions(-) create mode 100755 .github/scripts/run-aishell-zipformer-2023-10-24.sh create mode 100644 .github/workflows/run-aishell-zipformer-2023-10-24.yml create mode 100644 egs/aishell/ASR/zipformer/__init__.py create mode 120000 egs/aishell/ASR/zipformer/asr_datamodule.py create mode 120000 egs/aishell/ASR/zipformer/beam_search.py create mode 100755 egs/aishell/ASR/zipformer/decode.py create mode 120000 egs/aishell/ASR/zipformer/decode_stream.py create mode 120000 egs/aishell/ASR/zipformer/decoder.py create mode 120000 egs/aishell/ASR/zipformer/encoder_interface.py create mode 120000 egs/aishell/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/aishell/ASR/zipformer/export-onnx.py create mode 120000 egs/aishell/ASR/zipformer/export.py create mode 120000 egs/aishell/ASR/zipformer/jit_pretrained.py create mode 120000 egs/aishell/ASR/zipformer/jit_pretrained_streaming.py create mode 120000 egs/aishell/ASR/zipformer/joiner.py create mode 120000 egs/aishell/ASR/zipformer/model.py create mode 120000 egs/aishell/ASR/zipformer/onnx_check.py create mode 100755 egs/aishell/ASR/zipformer/onnx_decode.py create mode 120000 egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/aishell/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/aishell/ASR/zipformer/optim.py create mode 120000 egs/aishell/ASR/zipformer/pretrained.py create mode 120000 
egs/aishell/ASR/zipformer/scaling.py create mode 120000 egs/aishell/ASR/zipformer/scaling_converter.py create mode 120000 egs/aishell/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/aishell/ASR/zipformer/streaming_decode.py create mode 120000 egs/aishell/ASR/zipformer/subsampling.py create mode 100755 egs/aishell/ASR/zipformer/train.py create mode 120000 egs/aishell/ASR/zipformer/zipformer.py create mode 120000 egs/aishell4/ASR/pruned_transducer_stateless5/local diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh index 4c393f6be..c3640cfde 100755 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh @@ -18,8 +18,8 @@ log "Downloading pre-commputed fbank from $fbank_url" git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests ln -s $PWD/aishell-test-dev-manifests/data . -log "Downloading pre-trained model from $repo_url" repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 +log "Downloading pre-trained model from $repo_url" git clone $repo_url repo=$(basename $repo_url) diff --git a/.github/scripts/run-aishell-zipformer-2023-10-24.sh b/.github/scripts/run-aishell-zipformer-2023-10-24.sh new file mode 100755 index 000000000..865e29799 --- /dev/null +++ b/.github/scripts/run-aishell-zipformer-2023-10-24.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell/ASR + +git lfs install + +fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests +log "Downloading pre-commputed fbank from $fbank_url" + +git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests +ln -s $PWD/aishell-test-dev-manifests/data . 
+ +log "=======================" +log "CI testing large model" +repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/ +log "Downloading pre-trained model from $repo_url" +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done + +log "=======================" +log "CI testing medium model" +repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/ +log "Downloading pre-trained model from $repo_url" +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + + +for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done + + +log "=======================" +log "CI testing small model" +repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/ +log "Downloading pre-trained model from $repo_url" +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + + +for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done + diff --git a/.github/workflows/run-aishell-zipformer-2023-10-24.yml b/.github/workflows/run-aishell-zipformer-2023-10-24.yml new file mode 100644 index 000000000..f2fb44a5f --- /dev/null +++ b/.github/workflows/run-aishell-zipformer-2023-10-24.yml @@ -0,0 +1,95 @@ +# Copyright 2023 Zengrui Jin (Xiaomi Corp.) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-aishell-zipformer-2023-10-24 + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_aishell_zipformer_2023_10_24-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_aishell_zipformer_2023_10_24: + if: github.event.label.name == 'ready' || github.event.label.name == 'zipformer' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-aishell-zipformer-2023-10-24.sh + + \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 2eb0b3718..40ee2eb97 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -7,6 +7,8 @@ set -eou pipefail stage=-1 stop_stage=100 +perturb_speed=true + # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -77,7 +79,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Compute fbank for aidatatang_200zh" if [ ! 
-f data/fbank/.aidatatang_200zh.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True + ./local/compute_fbank_aidatatang_200zh.py --perturb-speed ${perturb_speed} touch data/fbank/.aidatatang_200zh.done fi fi diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 3667c2ad0..d491996b2 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -102,7 +102,7 @@ class Aidatatang_200zhAsrDataModule: group.add_argument( "--bucketing-sampler", type=str2bool, - default=True, + default=False, help="When enabled, the batches will come from buckets of " "similar duration (saves padding frames).", ) @@ -289,6 +289,7 @@ class Aidatatang_200zhAsrDataModule: shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, drop_last=True, + buffer_size=50000, ) else: logging.info("Using SimpleCutSampler.") diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index b9064cede..176f065e5 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -1,10 +1,12 @@ # Introduction -Please refer to -for how to run models in this recipe. +Please refer to for how to run models in this recipe. +Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co., Ltd. +400 people from different accent areas in China are invited to participate in the recording, which is conducted in a quiet indoor environment using high fidelity microphone and downsampled to 16kHz. The manual transcription accuracy is above 95%, through professional speech annotation and strict quality inspection. The data is free for academic use. We hope to provide moderate amount of data for new researchers in the field of speech recognition. +(From [Open Speech and Language Resources](https://www.openslr.org/33/)) # Transducers diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index a2d32013a..0b22f41a1 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,6 +1,162 @@ ## Results -### Aishell training result(Stateless Transducer) +### Aishell training result (Stateless Transducer) + +#### Zipformer (Non-streaming) + +[./zipformer](./zipformer) + +It's reworked Zipformer with Pruned RNNT loss. +**Caution**: It uses `--context-size=1`. 
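The prediction network in this recipe is stateless: it is just an embedding over the last `--context-size` non-blank tokens, so `--context-size=1` gives a bigram-style predictor, and the same value must be passed to the decoding and export commands below. A rough, self-contained sketch of the idea (simplified shapes, not the recipe's actual decoder code) is:

```python
# Illustrative-only stateless decoder: embed the last `context_size` non-blank
# tokens and, when context_size > 1, mix them with a depthwise 1-D convolution.
import torch
import torch.nn as nn


class TinyStatelessDecoder(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.context_size = context_size
        if context_size > 1:
            self.conv = nn.Conv1d(
                embedding_dim,
                embedding_dim,
                kernel_size=context_size,
                groups=embedding_dim,
                bias=False,
            )
        else:
            # With --context-size=1 there is nothing to convolve over.
            self.conv = nn.Identity()

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size) previous non-blank token ids
        emb = self.embedding(y)  # (N, context_size, embedding_dim)
        if self.context_size > 1:
            # (N, E, context_size) -> (N, E, 1) -> (N, 1, E)
            emb = self.conv(emb.permute(0, 2, 1)).permute(0, 2, 1)
        return emb  # (N, 1, embedding_dim)
```

With `context_size=1` the convolution degenerates to an identity, which is also why the decoder keeps an `nn.Identity()` placeholder for `torch.jit.script` export, as in the decoder fixes earlier in this patch series.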
+ +##### normal-scaled model, number of model parameters: 73412551, i.e., 73.41 M + +| | test | dev | comment | +|------------------------|------|------|-----------------------------------------| +| greedy search | 4.67 | 4.37 | --epoch 55 --avg 17 | +| modified beam search | 4.40 | 4.13 | --epoch 55 --avg 17 | +| fast beam search | 4.60 | 4.31 | --epoch 55 --avg 17 | + +Command for training is: +```bash +./prepare.sh + +export CUDA_VISIBLE_DEVICES="0,1" + +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --start-epoch 1 \ + --use-fp16 1 \ + --context-size 1 \ + --enable-musan 0 \ + --exp-dir zipformer/exp \ + --max-duration 1000 \ + --enable-musan 0 \ + --base-lr 0.045 \ + --lr-batches 7500 \ + --lr-epochs 18 \ + --spec-aug-time-warp-factor 20 +``` + +Command for decoding is: +```bash +for m in greedy_search modified_beam_search fast_beam_search ; do + ./zipformer/decode.py \ + --epoch 55 \ + --avg 17 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --context-size 1 \ + --decoding-method $m +done +``` +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at + + + +##### small-scaled model, number of model parameters: 30167139, i.e., 30.17 M + +| | test | dev | comment | +|------------------------|------|------|-----------------------------------------| +| greedy search | 4.97 | 4.67 | --epoch 55 --avg 21 | +| modified beam search | 4.67 | 4.40 | --epoch 55 --avg 21 | +| fast beam search | 4.85 | 4.61 | --epoch 55 --avg 21 | + +Command for training is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" + +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --start-epoch 1 \ + --use-fp16 1 \ + --context-size 1 \ + --exp-dir zipformer/exp-small \ + --enable-musan 0 \ + --base-lr 0.045 \ + --lr-batches 7500 \ + --lr-epochs 18 \ + --spec-aug-time-warp-factor 20 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --max-duration 1200 +``` + +Command for decoding is: +```bash +for m in greedy_search modified_beam_search fast_beam_search ; do + ./zipformer/decode.py \ + --epoch 55 \ + --avg 21 \ + --exp-dir ./zipformer/exp-small \ + --lang-dir data/lang_char \ + --context-size 1 \ + --decoding-method $m \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 +done +``` + +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at + + +##### large-scaled model, number of model parameters: 157285130, i.e., 157.29 M + +| | test | dev | comment | +|------------------------|------|------|-----------------------------------------| +| greedy search | 4.49 | 4.22 | --epoch 56 --avg 23 | +| modified beam search | 4.28 | 4.03 | --epoch 56 --avg 23 | +| fast beam search | 4.44 | 4.18 | --epoch 56 --avg 23 | + +Command for training is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" + +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --use-fp16 1 \ + --context-size 1 \ + --exp-dir ./zipformer/exp-large \ + --enable-musan 0 \ + --lr-batches 7500 \ + --lr-epochs 18 \ + --spec-aug-time-warp-factor 20 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 800 +``` + +Command for decoding is: +```bash 
+for m in greedy_search modified_beam_search fast_beam_search ; do + ./zipformer/decode.py \ + --epoch 56 \ + --avg 23 \ + --exp-dir ./zipformer/exp-large \ + --lang-dir data/lang_char \ + --context-size 1 \ + --decoding-method $m \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 +done +``` + +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at + #### Pruned transducer stateless 7 streaming [./pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index d5dbe5726..4feed55a8 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -8,6 +8,7 @@ set -eou pipefail nj=15 stage=-1 stop_stage=11 +perturb_speed=true # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -114,7 +115,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for aishell" if [ ! -f data/fbank/.aishell.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell.py --perturb-speed True + ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} touch data/fbank/.aishell.done fi fi diff --git a/egs/aishell/ASR/zipformer/__init__.py b/egs/aishell/ASR/zipformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/aishell/ASR/zipformer/asr_datamodule.py b/egs/aishell/ASR/zipformer/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/aishell/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/beam_search.py b/egs/aishell/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/aishell/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/decode.py b/egs/aishell/ASR/zipformer/decode.py new file mode 100755 index 000000000..1968904ae --- /dev/null +++ b/egs/aishell/ASR/zipformer/decode.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. 
The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + sentence = "".join([lexicon.word_table[i] for i in hyp]) + hyps.append(list(sentence)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, 
+ ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
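+    # len(dl) is undefined for iterable-style dataloaders and raises TypeError;
+    # the "?" fallback is only used in the progress logging below.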
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
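+        # Training typically passes comma-separated schedules for these options
+        # (e.g. "16,32,64,-1", assumed from the librispeech zipformer defaults),
+        # whereas decoding must fix a single value such as --chunk-size 16.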
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/decode_stream.py b/egs/aishell/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/aishell/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/decoder.py b/egs/aishell/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/aishell/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/encoder_interface.py b/egs/aishell/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/aishell/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/export-onnx-streaming.py b/egs/aishell/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/aishell/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/export-onnx.py b/egs/aishell/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/aishell/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git 
a/egs/aishell/ASR/zipformer/export.py b/egs/aishell/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/aishell/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/jit_pretrained.py b/egs/aishell/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/aishell/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py b/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/joiner.py b/egs/aishell/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/aishell/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/model.py b/egs/aishell/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/aishell/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_check.py b/egs/aishell/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/aishell/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_decode.py b/egs/aishell/ASR/zipformer/onnx_decode.py new file mode 100755 index 000000000..17c6eceb4 --- /dev/null +++ b/egs/aishell/ASR/zipformer/onnx_decode.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. +""" + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from lhotse.cut import Cut +from onnx_pretrained import OnnxModel, greedy_search + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. 
", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: k2.SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + Mapping ids to tokens. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [[token_table[h] for h in hyp] for hyp in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: k2.SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + Mapping ids to tokens. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = list(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = k2.SymbolTable.from_file(args.tokens) + assert token_table[0] == "" + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_net_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py b/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_pretrained.py b/egs/aishell/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/aishell/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/optim.py b/egs/aishell/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/aishell/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/pretrained.py 
b/egs/aishell/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/aishell/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/scaling.py b/egs/aishell/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/aishell/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/scaling_converter.py b/egs/aishell/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/aishell/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/streaming_beam_search.py b/egs/aishell/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/aishell/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..c3820447a --- /dev/null +++ b/egs/aishell/ASR/zipformer/streaming_decode.py @@ -0,0 +1,880 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +from asr_datamodule import AishellAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="Path to the lang dir(containing lexicon, tokens, etc.)", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. 
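+
+    Example:
+      For two utterances with states ``s0`` and ``s1``, ``stack_states([s0, s1])``
+      concatenates each cached tensor of ``s0`` and ``s1`` along its batch
+      dimension, and ``unstack_states(stack_states([s0, s1]))`` recovers
+      ``[s0, s1]``.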
+ """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
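+      Every tensor in a returned per-utterance state keeps a batch dimension of
+      size 1, matching the layout produced by :func:`get_init_states` with
+      ``batch_size=1``.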
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
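+    ``features`` holds the padded fbank frames of the current chunk,
+    ``states`` is the stacked state list produced by :func:`stack_states`, and
+    ``chunk_size`` / ``left_context_len`` are counted in frames after
+    encoder_embed (50 Hz frame rate).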
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
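+    # For example, with --chunk-size 16 the tail_length computed below is
+    # 16 * 2 + 7 + 2 * 3 = 45 input frames.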
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + encoder_out=encoder_out, + streams=decode_streams, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + lexicon: + The Lexicon. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
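+        # Streams are accumulated until params.num_decode_streams are active;
+        # they are then decoded chunk by chunk in the while-loops below, and
+        # finished streams are drained as they complete.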
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + if audio.max() > 1: + logging.warning( + f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." + f"Clipping to [-1, 1]." + ) + audio = np.clip(audio, -1, 1) + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + list(decode_streams[i].ground_truth.strip()), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + key = f"greedy_search_{key}" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_{key}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}_{key}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, 
iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + aishell = AishellAsrDataModule(args) + + dev_cuts = aishell.valid_cuts() + test_cuts = aishell.test_cuts() + + test_sets = ["dev", "test"] + test_cuts = [dev_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + ) + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/subsampling.py b/egs/aishell/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/aishell/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py new file mode 100755 index 000000000..7e7b02829 --- /dev/null +++ b/egs/aishell/ASR/zipformer/train.py @@ -0,0 +1,1350 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
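The epoch-range averaging used in the decode script above (`average_checkpoints_with_averaged_model`) relies on each checkpoint storing a running parameter average together with `batch_idx_train`. Assuming that layout, the average over the half-open range (start, end] can be recovered from the two endpoint snapshots alone; the following is a minimal illustration with made-up tensors, not the library implementation:

```python
import torch


def average_over_range(avg_start, n_start, avg_end, n_end):
    """Recover the average of the parameters seen in batches (n_start, n_end],
    given running averages taken after n_start and n_end batches.

    avg_start / avg_end are dicts of tensors; n_start / n_end are the
    corresponding batch_idx_train values (assumed checkpoint layout).
    """
    assert n_end > n_start >= 0
    return {
        name: (avg_end[name] * n_end - avg_start[name] * n_start) / (n_end - n_start)
        for name in avg_end
    }


# Toy example with a single scalar "parameter":
snap_a = {"w": torch.tensor(2.0)}  # running average after 1000 batches
snap_b = {"w": torch.tensor(4.0)}  # running average after 3000 batches
print(average_over_range(snap_a, 1000, snap_b, 3000))  # {'w': tensor(5.)}
```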
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --exp-dir zipformer/exp \ + --training-subset L + --lr-epochs 1.5 \ + --max-duration 350 + +# For mix precision training: + +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --training-subset L \ + --lr-epochs 1.5 \ + --max-duration 750 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="""Maximum left-contexts for causal training, measured in frames which will + be converted to a number of chunks. 
If splitting into chunks, + chunk left-context frames will be chosen randomly from this list; else not relevant.""", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). 
We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
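+            # Concretely: every 100 batches the scale is doubled if it has
+            # fallen below 8.0 (or below 32.0, but only on every 400th batch).
+            # A scale below 0.01 triggers a one-time "bad model" save as a
+            # warning, and below 1e-5 training is aborted.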
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + valid_cuts = aishell.valid_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 12.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + The compiler to encode texts to ids. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = supervisions["text"] + y = graph_compiler.texts_to_ids(texts) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/zipformer.py b/egs/aishell/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/aishell/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/aishell2/ASR/README.md b/egs/aishell2/ASR/README.md index ba38a1ec7..4e786af11 100644 --- a/egs/aishell2/ASR/README.md +++ b/egs/aishell2/ASR/README.md @@ -1,7 +1,11 @@ # Introduction -This recipe includes some different ASR models trained with Aishell2. +This recipe contains various different ASR models trained with Aishell2. + +In AISHELL-2, 1000 hours of clean read-speech data from iOS is published, which is free for academic usage. On top of AISHELL-2 corpus, an improved recipe is developed and released, containing key components for industrial applications, such as Chinese word segmentation, flexible vocabulary expension and phone set transformation etc. Pipelines support various state-of-the-art techniques, such as time-delayed neural networks and Lattic-Free MMI objective funciton. In addition, we also release dev and test data from other channels (Android and Mic). + +(From [AISHELL-2: Transforming Mandarin ASR Research Into Industrial Scale](https://arxiv.org/abs/1808.10583)) [./RESULTS.md](./RESULTS.md) contains the latest results. diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md index 7114bd5f5..32ad74b50 100644 --- a/egs/aishell2/ASR/RESULTS.md +++ b/egs/aishell2/ASR/RESULTS.md @@ -1,8 +1,8 @@ ## Results -### Aishell2 char-based training results (Pruned Transducer 5) +### Aishell2 char-based training results -#### 2022-07-11 +#### Pruned transducer stateless 5 Using the codes from this commit https://github.com/k2-fsa/icefall/pull/465. 
@@ -41,9 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" The decoding command is: ```bash -for method in greedy_search modified_beam_search \ - fast_beam_search fast_beam_search_nbest \ - fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do +for method in greedy_search modified_beam_search fast_beam_search fast_beam_search_nbest fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do ./pruned_transducer_stateless5/decode.py \ --epoch 25 \ --avg 5 \ diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 42631c864..6eb6268f5 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -7,7 +7,9 @@ set -eou pipefail nj=30 stage=0 -stop_stage=5 +stop_stage=7 +perturb_speed=true + # We assume dl_dir (download dir) contains the following # directories and files. If not, you need to apply aishell2 through @@ -101,7 +103,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for aishell2" if [ ! -f data/fbank/.aishell2.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell2.py --perturb-speed True + ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} touch data/fbank/.aishell2.done fi fi diff --git a/egs/aishell4/ASR/README.md b/egs/aishell4/ASR/README.md index 3744032f8..67fa17790 100644 --- a/egs/aishell4/ASR/README.md +++ b/egs/aishell4/ASR/README.md @@ -1,7 +1,11 @@ # Introduction -This recipe includes some different ASR models trained with Aishell4 (including S, M and L three subsets). +This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets). + +The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. + +(From [Open Speech and Language Resources](https://www.openslr.org/111/)) [./RESULTS.md](./RESULTS.md) contains the latest results. diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index 1b1ec0005..361cc26ab 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -7,6 +7,8 @@ set -eou pipefail stage=-1 stop_stage=100 +perturb_speed=true + # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -107,7 +109,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for aishell4" if [ ! 
-f data/fbank/.aishell4.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell4.py --perturb-speed True + ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} touch data/fbank/.aishell4.done fi fi diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index 4ad98fb51..e6db2651f 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -306,7 +306,7 @@ class Aishell4AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=30000, + buffer_size=100000, drop_last=self.args.drop_last, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/local b/egs/aishell4/ASR/pruned_transducer_stateless5/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/local @@ -0,0 +1 @@ +../local \ No newline at end of file From f82bccfd63d4f02fbe5050e3c2d972dc69656215 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 24 Oct 2023 19:04:09 +0800 Subject: [PATCH 030/216] Support CTC decoding for `multi-zh_hans` recipe (#1313) --- .../scripts/run-multi-zh_hans-zipformer.sh | 44 ++ .../workflows/run-multi-zh_hans-zipformer.yml | 2 +- egs/multi_zh-hans/ASR/RESULTS.md | 43 +- egs/multi_zh-hans/ASR/zipformer/ctc_decode.py | 625 ++++++++++++++++++ 4 files changed, 709 insertions(+), 5 deletions(-) create mode 100755 egs/multi_zh-hans/ASR/zipformer/ctc_decode.py diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-zh_hans-zipformer.sh index 2bc3137d8..dd32a94f8 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-zh_hans-zipformer.sh @@ -10,6 +10,7 @@ log() { cd egs/multi_zh-hans/ASR +log "==== Test icefall-asr-multi-zh-hans-zipformer-2023-9-2 ====" repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ log "Downloading pre-trained model from $repo_url" @@ -49,3 +50,46 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav done + +log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ====" +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-20.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + + +./zipformer/pretrained.py \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-ctc 1 \ + --method greedy_search \ +$repo/test_wavs/DEV_T0000000000.wav \ +$repo/test_wavs/DEV_T0000000001.wav \ +$repo/test_wavs/DEV_T0000000002.wav + +for method in modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --use-ctc 1 \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav +done \ No newline at end of file diff --git a/.github/workflows/run-multi-zh_hans-zipformer.yml b/.github/workflows/run-multi-zh_hans-zipformer.yml index 4ec81585f..72c0775a7 100644 --- 
a/.github/workflows/run-multi-zh_hans-zipformer.yml +++ b/.github/workflows/run-multi-zh_hans-zipformer.yml @@ -29,7 +29,7 @@ concurrency: jobs: run_multi-zh_hans_zipformer: - if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index 31fbd9700..5133229a7 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -4,6 +4,41 @@ This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall. +#### Non-streaming (with CTC head) + +Best results (num of params : ~69M): + +The training command: + +``` +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 20 \ + --use-fp16 1 \ + --max-duration 600 \ + --num-workers 8 \ + --use-ctc 1 +``` + +The decoding command: + +``` +./zipformer/decode.py \ + --epoch 20 \ + --avg 1 \ + --use-ctc 1 +``` + +Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using BPE model ( # tokens is 2000, byte fallback enabled). + +| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | +|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| +| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| CTC Decoding | 14.57 | 15.26 | 72.85 | 69.70 | 12.87 | 13.76 | 23.56 | 25.55 | 71.75 | 22.35 | 19.34 | 42.38 | 26.90 | 48.71 | 64.88 | 67.29 | 54.24 | +| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 | + +Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ + #### Non-streaming Best results (num of params : ~69M): @@ -29,10 +64,10 @@ The decoding command: Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model ( # tokens is 2000, byte fallback enabled). 
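Byte fallback means that characters outside the 2000-piece BPE vocabulary are decomposed into UTF-8 byte tokens rather than collapsing to `<unk>`, which matters for rare Chinese characters. A small self-contained sketch using the `sentencepiece` package; the file name and option values here are illustrative, not the recipe's exact settings:

```python
import sentencepiece as spm

# Train a toy BPE model with byte fallback enabled (illustrative settings).
spm.SentencePieceTrainer.train(
    input="transcripts.txt",      # hypothetical plain-text transcript file
    model_prefix="bpe_demo",
    vocab_size=2000,
    model_type="bpe",
    byte_fallback=True,           # out-of-vocabulary characters -> byte pieces
    character_coverage=0.98,
)

sp = spm.SentencePieceProcessor(model_file="bpe_demo.model")
ids = sp.encode("这是一个测试", out_type=int)
print(ids)             # token IDs; rare characters fall back to byte tokens
print(sp.decode(ids))  # decoding reassembles the original string
```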
-| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | +| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | |--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 | +| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 | -The pre-trained model is available here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2 +Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ diff --git a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..a7cd7ce43 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
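A condensed sketch of the `ctc-decoding` path that this new file implements: build a CTC topology `H`, intersect it with the network's frame-level log-probabilities to obtain a lattice, then take the one-best path and map the resulting token IDs back to text with the BPE model. The tensors below are random placeholders that only show the shapes and calls; the real script feeds encoder outputs from the trained model.

```python
import k2
import torch

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts

device = torch.device("cpu")
vocab_size = 500  # placeholder; the recipe uses a 2000-token BPE vocabulary

# Fake CTC log-probs for one 20-frame utterance, shape (N, T, C).
nnet_output = torch.randn(1, 20, vocab_size).log_softmax(dim=-1)

# One supervision covering the whole utterance: (sequence_idx, start, num_frames).
supervision_segments = torch.tensor([[0, 0, 20]], dtype=torch.int32)

H = k2.ctc_topo(max_token=vocab_size - 1, modified=False, device=device)

lattice = get_lattice(
    nnet_output=nnet_output,
    decoding_graph=H,
    supervision_segments=supervision_segments,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
    subsampling_factor=1,
)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
token_ids = get_texts(best_path)  # list of token-ID lists, one per utterance
# The real script then calls bpe_model.decode(token_ids) to obtain text.
```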
+""" +Usage: + +(1) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import get_lattice, one_best_decoding +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. 
+ Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
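+        # The padding below appends `pad_len` frames filled with LOG_EPS
+        # (a very small log-fbank value, i.e. near-silence) and extends
+        # `feature_lens` to match, presumably so that the tail of the
+        # utterance still receives some right context in the causal encoder.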
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
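+    # The loop below decodes the whole dataset batch by batch.  `results`
+    # maps a decoding-method key (here always "ctc-decoding") to a list of
+    # (cut_id, ref_words, hyp_words) tuples, which `save_results` later
+    # writes to disk and scores.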
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ("ctc-decoding",) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
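+        # The chunk size and left-context length become part of the result
+        # file suffix, so log and result files for different streaming
+        # settings get distinct names.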
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=True, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + + G = None + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
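+    # With `return_cuts=True` the dataloader attaches the original cuts to
+    # each batch under batch["supervisions"]["cut"], which is where
+    # `decode_dataset` reads the cut ids from.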
+ args.return_cuts = True + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) + + test_sets_cuts = multi_dataset.test_cuts() + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() From 1814bbb0e7afcfcfe495322d2abd4fbfb21510c4 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 25 Oct 2023 00:03:33 +0800 Subject: [PATCH 031/216] typo fixed (#1334) --- egs/aishell/ASR/prepare.sh | 2 +- egs/aishell2/ASR/prepare.sh | 2 +- egs/gigaspeech/ASR/prepare.sh | 2 +- egs/librispeech/ASR/prepare.sh | 2 +- egs/librispeech/WSASR/prepare.sh | 2 +- egs/mgb2/ASR/prepare.sh | 2 +- egs/swbd/ASR/prepare.sh | 2 +- egs/tal_csasr/ASR/prepare.sh | 2 +- egs/tedlium3/ASR/prepare.sh | 2 +- egs/wenetspeech/ASR/prepare.sh | 2 +- egs/xbmu_amdo31/ASR/prepare.sh | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 4feed55a8..d36dc5ed3 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -243,7 +243,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then -lm data/lm/3-gram.unpruned.arpa fi - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then # It is used in building HLG diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 6eb6268f5..a5eb9bd13 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -159,7 +159,7 @@ fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm if [ ! 
-f ${lang_char_dir}/3-gram.unpruned.arpa ]; then diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index bd255dc6a..a23b708d7 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -293,7 +293,7 @@ fi if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then log "Stage 12: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 739608572..4a5072cc0 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -278,7 +278,7 @@ fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then log "Stage 8: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm diff --git a/egs/librispeech/WSASR/prepare.sh b/egs/librispeech/WSASR/prepare.sh index f6a922fde..0d2a67259 100755 --- a/egs/librispeech/WSASR/prepare.sh +++ b/egs/librispeech/WSASR/prepare.sh @@ -193,7 +193,7 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p "${lm_dir}" diff --git a/egs/mgb2/ASR/prepare.sh b/egs/mgb2/ASR/prepare.sh index 899d15d97..4ea427371 100755 --- a/egs/mgb2/ASR/prepare.sh +++ b/egs/mgb2/ASR/prepare.sh @@ -188,7 +188,7 @@ fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then log "Stage 8: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 47d12613b..6b6f4ff86 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -311,7 +311,7 @@ fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then log "Stage 8: Prepare G" lang_dir=data/lang_phone - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh index 352e8ba66..2de4ac8f5 100755 --- a/egs/tal_csasr/ASR/prepare.sh +++ b/egs/tal_csasr/ASR/prepare.sh @@ -150,7 +150,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi # Prepare words.txt - # We assume you have install jieba, if not, please install + # We assume you have installed jieba, if not, please install # it using: pip install jieba if [ ! 
-f $lang_char_dir/words.txt ]; then python -m jieba $lang_char_dir/text | sed 's/\///g;s/\s\+/ /g' > $lang_char_dir/text.seg diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh index 3d90436ff..2f58ca0ee 100755 --- a/egs/tedlium3/ASR/prepare.sh +++ b/egs/tedlium3/ASR/prepare.sh @@ -172,7 +172,7 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 097a59a5f..f7eb9f0d0 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -237,7 +237,7 @@ fi if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then log "Stage 17: Prepare G" # It will take about 20 minutes. - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then python3 ./shared/make_kn_lm.py \ diff --git a/egs/xbmu_amdo31/ASR/prepare.sh b/egs/xbmu_amdo31/ASR/prepare.sh index 32ae440f7..21836840c 100755 --- a/egs/xbmu_amdo31/ASR/prepare.sh +++ b/egs/xbmu_amdo31/ASR/prepare.sh @@ -224,7 +224,7 @@ fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then log "Stage 8: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm From dcbc7a63e117c8fdd4003bb8d998d7a0b6376aa2 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 25 Oct 2023 12:50:35 +0800 Subject: [PATCH 032/216] Update train-rnn-lm.sh (#1337) --- egs/ptb/LM/train-rnn-lm.sh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/egs/ptb/LM/train-rnn-lm.sh b/egs/ptb/LM/train-rnn-lm.sh index 29c609ee1..cb70b7856 100755 --- a/egs/ptb/LM/train-rnn-lm.sh +++ b/egs/ptb/LM/train-rnn-lm.sh @@ -37,10 +37,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then --world-size $world_size \ --use-fp16 0 \ --vocab-size 500 \ - \ --lm-data ./data/lm_training_bpe_500/sorted_lm_data.pt \ --lm-data-valid ./data/lm_training_bpe_500/sorted_lm_data-valid.pt \ - \ --embedding-dim 800 \ --hidden-dim 200 \ --num-layers 2 \ @@ -56,9 +54,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --epoch $use_epoch \ --avg $use_avg \ --vocab-size 500 \ - \ --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt \ - \ --embedding-dim 800 \ --hidden-dim 200 \ --num-layers 2 \ From 770c495484f2f244e1b54ea51fea1661f48a0a06 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 25 Oct 2023 17:14:17 +0800 Subject: [PATCH 033/216] minor fixes in the CTC decoding code (#1338) --- egs/multi_zh-hans/ASR/RESULTS.md | 4 ++-- egs/multi_zh-hans/ASR/zipformer/ctc_decode.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index 5133229a7..15e789604 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -33,8 +33,8 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the | Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | 
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| CTC Decoding | 14.57 | 15.26 | 72.85 | 69.70 | 12.87 | 13.76 | 23.56 | 25.55 | 71.75 | 22.35 | 19.34 | 42.38 | 26.90 | 48.71 | 64.88 | 67.29 | 54.24 | +| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 | | Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 | Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ diff --git a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py index a7cd7ce43..5143f945d 100755 --- a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py @@ -379,7 +379,8 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() + ref_words = list(ref_text.replace(" ", "")) + hyp_words = list("".join(hyp_words)) this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) From c0a53271e2fe64dd02939bb6e2ff3a2938715b48 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Thu, 26 Oct 2023 17:35:12 +0800 Subject: [PATCH 034/216] Update Zipformer-large result on LibriSpeech (#1343) * update zipformer-large result on librispeech --- README.md | 11 +++---- egs/librispeech/ASR/RESULTS.md | 52 ++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index a14abd023..81efda32a 100644 --- a/README.md +++ b/README.md @@ -118,11 +118,12 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles #### k2 pruned RNN-T -| Encoder | Params | test-clean | test-other | -|-----------------|--------|------------|------------| -| zipformer | 65.5M | 2.21 | 4.79 | -| zipformer-small | 23.2M | 2.42 | 5.73 | -| zipformer-large | 148.4M | 2.06 | 4.63 | +| Encoder | Params | test-clean | test-other | epochs | devices | +|-----------------|--------|------------|------------|---------|------------| +| zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 | +| zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 | +| zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 | +| zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 | Note: No auxiliary losses are used in the training and no LMs are used in the decoding. 
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index fc7fcdc26..a1808edd3 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -245,6 +245,58 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` +##### large-scaled model, number of model parameters: 148439574, i.e., 148.4 M, trained on 8 80G-A100 GPUs + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|-----------------------| +| greedy_search | 2.00 | 4.47 | --epoch 174 --avg 172 | +| modified_beam_search | 2.00 | 4.38 | --epoch 174 --avg 172 | +| fast_beam_search | 2.00 | 4.42 | --epoch 174 --avg 172 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 174 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --causal 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --full-libri 1 \ + --max-duration 2200 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search fast_beam_search; do + ./zipformer/decode.py \ + --epoch 174 \ + --avg 172 \ + --exp-dir zipformer/exp-large \ + --max-duration 600 \ + --causal 0 \ + --decoding-method $m \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 +done +``` + #### streaming ##### normal-scaled model, number of model parameters: 66110931, i.e., 66.11 M From 800bf4b6a2e32745e7d0c31dd78d473f1faff509 Mon Sep 17 00:00:00 2001 From: hairyputtar <148847552+hairyputtar@users.noreply.github.com> Date: Fri, 27 Oct 2023 09:16:28 +0530 Subject: [PATCH 035/216] fix more typos (#1340) * fix more typos * fix typo * fix typo * fix typo --- docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst | 2 +- docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst | 2 +- .../librispeech/pruned_transducer_stateless.rst | 2 +- docs/source/recipes/RNN-LM/librispeech/lm-training.rst | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst index 6e30ce397..aad90f9d0 100644 --- a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst @@ -67,7 +67,7 @@ To run stage 2 to stage 5, use: .. HINT:: A 3-gram language model will be downloaded from huggingface, we assume you have - intalled and initialized ``git-lfs``. If not, you could install ``git-lfs`` by + installed and initialized ``git-lfs``. If not, you could install ``git-lfs`` by .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst index 9eb3b11f7..8e56deb6a 100644 --- a/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst @@ -67,7 +67,7 @@ To run stage 2 to stage 5, use: .. 
HINT:: A 3-gram language model will be downloaded from huggingface, we assume you have - intalled and initialized ``git-lfs``. If not, you could install ``git-lfs`` by + installed and initialized ``git-lfs``. If not, you could install ``git-lfs`` by .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst index 1bc1dd984..f356e97e7 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -418,7 +418,7 @@ The following shows two examples (for two types of checkpoints): - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and `espnet/nets/beam_search_transducer.py `_ - is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to + is used as a reference. Basically, it keeps topk states for each frame, and expands the kept states with their own contexts to next frame. - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it diff --git a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst index 736120275..46499a374 100644 --- a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst +++ b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst @@ -1,6 +1,6 @@ .. _train_nnlm: -Train an RNN langugage model +Train an RNN language model ====================================== If you have enough text data, you can train a neural network language model (NNLM) to improve From ea78b328575f9533c7d34db6f9cd0f44b09b6092 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 27 Oct 2023 13:35:43 +0800 Subject: [PATCH 036/216] minor fixes (#1345) --- egs/tedlium3/ASR/zipformer/decode.py | 4 ++-- egs/tedlium3/ASR/zipformer/train.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/tedlium3/ASR/zipformer/decode.py b/egs/tedlium3/ASR/zipformer/decode.py index ea1cbba1b..2c4123c20 100755 --- a/egs/tedlium3/ASR/zipformer/decode.py +++ b/egs/tedlium3/ASR/zipformer/decode.py @@ -116,7 +116,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import add_model_arguments, get_params, get_transducer_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -695,7 +695,7 @@ def main(): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) if not params.use_averaged_model: if params.iter > 0: diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 33d03908c..5ad01df27 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -586,7 +586,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner -def get_transducer_model(params: AttributeDict) -> nn.Module: +def get_model(params: AttributeDict) -> nn.Module: encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) decoder = get_decoder_model(params) @@ -1083,7 +1083,7 @@ def run(rank, world_size, args): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model 
parameters: {num_param}") From 5cebecf2dcebbfb7284cc2577d1e50a33933c663 Mon Sep 17 00:00:00 2001 From: Shreyas0410 <70795867+Shreyas0410@users.noreply.github.com> Date: Fri, 27 Oct 2023 11:06:15 +0530 Subject: [PATCH 037/216] updated broken link in read.me file (#1342) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 81efda32a..15e9e17e6 100644 --- a/README.md +++ b/README.md @@ -367,7 +367,7 @@ Once you have trained a model in icefall, you may want to deploy it with C++, without Python dependencies. Please refer to the documentation - + for how to do this. We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++. From 7d56685734cbdd9170caae7fada2d64b27cab2b3 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Fri, 27 Oct 2023 01:38:09 -0400 Subject: [PATCH 038/216] [recipe] LibriSpeech zipformer_ctc (#941) * merge upstream * initial commit for zipformer_ctc * remove unwanted changes * remove changes to other recipe * fix zipformer softlink * fix for JIT export * add missing file * fix symbolic links * update results * Update RESULTS.md Address comments from @csukuangfj --------- Co-authored-by: zr_jin --- egs/librispeech/ASR/README.md | 1 + egs/librispeech/ASR/RESULTS.md | 51 +- egs/librispeech/ASR/zipformer_ctc/__init__.py | 0 .../ASR/zipformer_ctc/asr_datamodule.py | 1 + egs/librispeech/ASR/zipformer_ctc/decode.py | 886 +++++++++++++ egs/librispeech/ASR/zipformer_ctc/decoder.py | 298 +++++ .../ASR/zipformer_ctc/encoder_interface.py | 1 + egs/librispeech/ASR/zipformer_ctc/export.py | 240 ++++ .../ASR/zipformer_ctc/label_smoothing.py | 1 + egs/librispeech/ASR/zipformer_ctc/model.py | 158 +++ egs/librispeech/ASR/zipformer_ctc/optim.py | 1 + egs/librispeech/ASR/zipformer_ctc/scaling.py | 1 + .../ASR/zipformer_ctc/scaling_converter.py | 1 + .../ASR/zipformer_ctc/subsampling.py | 1 + egs/librispeech/ASR/zipformer_ctc/train.py | 1135 +++++++++++++++++ .../ASR/zipformer_ctc/transformer.py | 1 + .../ASR/zipformer_ctc/zipformer.py | 1 + 17 files changed, 2777 insertions(+), 1 deletion(-) create mode 100644 egs/librispeech/ASR/zipformer_ctc/__init__.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py create mode 100755 egs/librispeech/ASR/zipformer_ctc/decode.py create mode 100644 egs/librispeech/ASR/zipformer_ctc/decoder.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/encoder_interface.py create mode 100755 egs/librispeech/ASR/zipformer_ctc/export.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/label_smoothing.py create mode 100644 egs/librispeech/ASR/zipformer_ctc/model.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/optim.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/scaling.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/scaling_converter.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/subsampling.py create mode 100755 egs/librispeech/ASR/zipformer_ctc/train.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/transformer.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/zipformer.py diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index f42750da9..1c8930818 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -47,6 +47,7 @@ We place an additional Conv1d layer right after the input embedding layer. 
| `conformer-ctc` | Conformer | Use auxiliary attention head | | `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | | `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | +| `zipformer-ctc` | Zipformer | Use auxiliary attention head | | `zipformer` | Upgraded Zipformer | Use auxiliary transducer head | The latest recipe | # MMI diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index a1808edd3..ebf5e89c4 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -375,6 +375,55 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` +### Zipformer CTC + +#### [zipformer_ctc](./zipformer_ctc) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 86083707, i.e., 86.08 M + +| decoding method | test-clean | test-other | comment | +|-------------------------|------------|------------|---------------------| +| ctc-decoding | 2.50 | 5.86 | --epoch 30 --avg 9 | +| whole-lattice-rescoring | 2.44 | 5.38 | --epoch 30 --avg 9 | +| attention-rescoring | 2.35 | 5.16 | --epoch 30 --avg 9 | +| 1best | 2.01 | 4.61 | --epoch 30 --avg 9 | + +The training commands are: +```bash + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer_ctc/exp \ + --full-libri 1 \ + --max-duration 1000 \ + --master-port 12345 +``` + +The tensorboard log can be found at: + + +The decoding command is: + +```bash +./zipformer_ctc/decode.py \ + --epoch 30 --avg 9 --use-averaged-model True \ + --exp-dir zipformer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --lm-dir data/lm \ + --method ctc-decoding +``` + ### pruned_transducer_stateless7 (Fine-tune with mux) See for more details. @@ -616,7 +665,6 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` - #### Smaller model We also provide a very small version (only 6.1M parameters) of this setup. The training command for the small model is: @@ -663,6 +711,7 @@ This small model achieves the following WERs on GigaSpeech test and dev sets: You can find the tensorboard logs at . + ### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) #### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) diff --git a/egs/librispeech/ASR/zipformer_ctc/__init__.py b/egs/librispeech/ASR/zipformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/decode.py b/egs/librispeech/ASR/zipformer_ctc/decode.py new file mode 100755 index 000000000..7f605e2c8 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/decode.py @@ -0,0 +1,886 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_ctc_model, get_params +from transformer import encoder_padding_mask + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_rnn_lm, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=77, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. 
Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. 
+ + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + nnet_output, _ = model.encoder(feature, feature_lens) + ctc_output = model.ctc_output(nnet_output) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
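+        # `nbest_oracle` samples `num_paths` paths from the lattice and keeps
+        # the one closest to the reference transcript, so the resulting WER is
+        # a lower bound on what any n-best rescoring method could achieve.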
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + + nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(nnet_output.size(0), supervisions) + mask = mask.to(nnet_output.device) if mask is not None else None + mmodel = model.decoder.module if hasattr(model.decoder, "module") else model.decoder + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=mmodel, + memory=nnet_output, + memory_key_padding_mask=mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=mmodel, + memory=nnet_output, + memory_key_padding_mask=mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. 
+ + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
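+        # `write_error_stats` also returns the WER for this decoding setting;
+        # the values are collected in `test_set_wers` and summarised in the
+        # wer-summary file below.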
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + params.vocab_size = num_classes + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. 
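+            # Save the converted LM (together with the dummy attribute) so
+            # that later runs load the cached G_4_gram.pt instead of
+            # re-parsing the much slower text-format FST.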
+ G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_ctc_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + 
rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_ctc/decoder.py b/egs/librispeech/ASR/zipformer_ctc/decoder.py new file mode 100644 index 000000000..8dec048a1 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/decoder.py @@ -0,0 +1,298 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from label_smoothing import LabelSmoothingLoss +from torch.nn.utils.rnn import pad_sequence +from transformer import PositionalEncoding, TransformerDecoderLayer + + +class Decoder(nn.Module): + """This class implements Transformer based decoder for an attention-based encoder-decoder + model. + """ + + def __init__( + self, + num_layers: int, + num_classes: int, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + dropout: float = 0.1, + normalize_before: bool = True, + ): + """ + Args: + num_layers: + Number of layers. + num_classes: + Number of tokens of the modeling unit including blank. + d_model: + Dimension of the input embedding, and of the decoder output. 
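+          nhead:
+            Number of attention heads in each decoder layer.
+          dim_feedforward:
+            Dimension of the feedforward sub-layer in each decoder layer.
+          dropout:
+            Dropout rate applied inside the decoder layers.
+          normalize_before:
+            If True, use pre-layer-norm in each decoder layer; otherwise
+            use post-layer-norm.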
+ """ + super().__init__() + + if num_layers > 0: + self.decoder_num_class = num_classes # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + @torch.jit.export + def forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). 
+ """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + return [utt + [eos_id] for utt in token_ids] + + +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: + """Generate a length mask for input. + The masked position are filled with True, + Unmasked positions are filled with False. + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). 
+ The mask can be used for masked self-attention. + For instance, if sz is 3, it returns:: + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + Args: + sz: mask size + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py b/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py new file mode 120000 index 000000000..b8529e0b7 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/export.py b/egs/librispeech/ASR/zipformer_ctc/export.py new file mode 100755 index 000000000..0ff50f128 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/export.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_ctc_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + params.vocab_size = num_classes + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = get_ctc_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + # We won't use the forward() method of the model in C++, 
so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + convert_scaled_to_non_scaled(model, inplace=True) + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py b/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py new file mode 120000 index 000000000..08734abd7 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/model.py b/egs/librispeech/ASR/zipformer_ctc/model.py new file mode 100644 index 000000000..2aeb8a072 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/model.py @@ -0,0 +1,158 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from transformer import encoder_padding_mask + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.utils import encode_supervisions + + +class CTCModel(nn.Module): + """It implements a CTC model with an auxiliary attention head.""" + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + encoder_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + An instance of `EncoderInterface`. The shared encoder for the CTC and attention + branches + decoder: + An instance of `nn.Module`. This is the decoder for the attention branch. + encoder_dim: + Dimension of the encoder output. + decoder_dim: + Dimension of the decoder output. + vocab_size: + Number of tokens of the modeling unit including blank. 
+ """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder = encoder + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + self.decoder = decoder + + @torch.jit.ignore + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + supervisions: torch.Tensor, + graph_compiler: BpeCtcTrainingGraphCompiler, + subsampling_factor: int = 1, + beam_size: int = 10, + reduction: str = "sum", + use_double_scores: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + Tensor of dimension (N, T, C) where N is the batch size, + T is the number of frames, and C is the feature dimension. + x_lens: + Tensor of dimension (N,) where N is the batch size. + supervisions: + Supervisions are used in training. + graph_compiler: + It is used to compile a decoding graph from texts. + subsampling_factor: + It is used to compute the `supervisions` for the encoder. + beam_size: + Beam size used in `k2.ctc_loss`. + reduction: + Reduction method used in `k2.ctc_loss`. + use_double_scores: + If True, use double precision in `k2.ctc_loss`. + Returns: + Return the CTC loss, attention loss, and the total number of frames. + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + + nnet_output, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + # compute ctc log-probs + ctc_output = self.ctc_output(nnet_output) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=subsampling_factor + ) + num_frames = supervision_segments[:, 2].sum().item() + + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments.cpu(), + allow_truncate=subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=beam_size, + reduction=reduction, + use_double_scores=use_double_scores, + ) + + if self.decoder is not None: + nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mmodel = ( + self.decoder.module if hasattr(self.decoder, "module") else self.decoder + ) + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + mask = encoder_padding_mask(nnet_output.size(0), supervisions) + mask = mask.to(nnet_output.device) if mask is not None else None + att_loss = mmodel.forward( + nnet_output, + mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + else: + att_loss = torch.tensor([0]) + + return ctc_loss, att_loss, num_frames diff --git a/egs/librispeech/ASR/zipformer_ctc/optim.py b/egs/librispeech/ASR/zipformer_ctc/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git 
a/egs/librispeech/ASR/zipformer_ctc/scaling.py b/egs/librispeech/ASR/zipformer_ctc/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py b/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/subsampling.py b/egs/librispeech/ASR/zipformer_ctc/subsampling.py new file mode 120000 index 000000000..6fee09e58 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/subsampling.py @@ -0,0 +1 @@ +../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py new file mode 100755 index 000000000..f40344357 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -0,0 +1,1135 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./zipformer_ctc/train.py \ + --exp-dir ./zipformer_ctc/exp \ + --world-size 4 \ + --full-libri 1 \ + --max-duration 500 \ + --num-epochs 30 +""" + +import argparse +import copy +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import CTCModel +from optim import Eden, LRScheduler, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. 
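+
+    For example (a minimal sketch using only names defined in this file):
+
+        params = get_params()
+        params.update(vars(get_parser().parse_args([])))
+        # Both the values above and the commandline defaults are now
+        # accessible as attributes, e.g. params.feature_dim, params.base_lr.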
+ """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + num_layers=params.num_decoder_layers, + num_classes=params.vocab_size, + d_model=int(params.encoder_dims.split(",")[-1]), + ) + return decoder + + +def get_ctc_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + + model = CTCModel( + encoder=encoder, + decoder=decoder, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
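+    Returns:
+      Return a tuple of two elements: the weighted loss (a scalar tensor) and
+      a `MetricsTracker` instance holding statistics for logging, e.g. the
+      number of frames and the individual CTC / attention loss values.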
+ """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + with torch.set_grad_enabled(is_training): + ctc_loss, att_loss, tot_frames = model( + feature, + feature_lens, + supervisions, + graph_compiler, + subsampling_factor=params.subsampling_factor, + beam_size=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + info = MetricsTracker() + info["frames"] = tot_frames + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + assert loss.requires_grad == is_training, f"{loss.requires_grad} != {is_training}" + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. 
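+      scheduler:
+        The learning rate scheduler; it is stepped once per batch.
+      scaler:
+        The GradScaler used for automatic mixed precision training.
+      model_avg:
+        The averaged model, updated periodically (on rank 0 only).
+      rank:
+        The rank of this node in DDP training. Non-zero ranks do not
+        write checkpoints.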
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + params.vocab_size = num_classes + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + logging.info("About to create model") + + model = get_ctc_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 25.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: BpeCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_ctc/transformer.py b/egs/librispeech/ASR/zipformer_ctc/transformer.py new file mode 120000 index 000000000..4c890cf29 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/transformer.py @@ -0,0 +1 @@ +../conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/zipformer.py b/egs/librispeech/ASR/zipformer_ctc/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file From 161ab90dfb951a49c7fb373861317fdcb9a9e7e4 Mon Sep 17 00:00:00 2001 From: Himanshu Kumar Mahto <93067059+HimanshuMahto@users.noreply.github.com> Date: Mon, 30 Oct 2023 06:37:42 +0530 Subject: [PATCH 039/216] Enhancing the contributing.md file (#1351) --- contributing.md | 58 ++++++++++++++++++++++++------------------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/contributing.md b/contributing.md index c8f06fdae..0a1f9936e 100644 --- a/contributing.md +++ b/contributing.md @@ -1,39 +1,37 @@ +# Contributing to Our Project -## Pre-commit hooks +Thank you for your interest in contributing to our project! We use Git pre-commit hooks to ensure code quality and consistency. Before contributing, please follow these guidelines to enable and use the pre-commit hooks. -We use [git][git] [pre-commit][pre-commit] [hooks][hooks] to check that files -going to be committed: +## Pre-Commit Hooks - - contain no trailing spaces - - are formatted with [black][black] - - are compatible to [PEP8][PEP8] (checked by [flake8][flake8]) - - end in a newline and only a newline - - contain sorted `imports` (checked by [isort][isort]) +We have set up pre-commit hooks to check that the files you're committing meet our coding and formatting standards. These checks include: -These hooks are disabled by default. Please use the following commands to enable them: +- Ensuring there are no trailing spaces. +- Formatting code with [black](https://github.com/psf/black). +- Checking compliance with PEP8 using [flake8](https://flake8.pycqa.org/). +- Verifying that files end with a newline character (and only a newline). +- Sorting imports using [isort](https://pycqa.github.io/isort/). 
-```bash -pip install pre-commit # run it only once -pre-commit install # run it only once, it will install all hooks +Please note that these hooks are disabled by default. To enable them, follow these steps: -# modify some files -git add -git commit # It runs all hooks automatically. +### Installation (Run only once) -# If all hooks run successfully, you can write the commit message now. Done! -# -# If any hook failed, your commit was not successful. -# Please read the error messages and make changes accordingly. -# And rerun +1. Install the `pre-commit` package using pip: + ```bash + pip install pre-commit + ``` +1. Install the Git hooks using: + ```bash + pre-commit install + ``` +### Making a Commit +Once you have enabled the pre-commit hooks, follow these steps when making a commit: +1. Make your changes to the codebase. +2. Stage your changes by using git add for the files you modified. +3. Commit your changes using git commit. The pre-commit hooks will run automatically at this point. +4. If all hooks run successfully, you can write your commit message, and your changes will be successfully committed. +5. If any hook fails, your commit will not be successful. Please read and follow the error messages provided, make the necessary changes, and then re-run git add and git commit. -git add -git commit -``` +### Your Contribution +Your contributions are valuable to us, and by following these guidelines, you help maintain code consistency and quality in our project. We appreciate your dedication to ensuring high-quality code. If you have questions or need assistance, feel free to reach out to us. Thank you for being part of our open-source community! -[git]: https://git-scm.com/book/en/v2/Customizing-Git-Git-Hooks -[flake8]: https://github.com/PyCQA/flake8 -[PEP8]: https://www.python.org/dev/peps/pep-0008/ -[black]: https://github.com/psf/black -[hooks]: https://github.com/pre-commit/pre-commit-hooks -[pre-commit]: https://github.com/pre-commit/pre-commit -[isort]: https://github.com/PyCQA/isort From c970df512b189b147e7fe6d45a7e8eb8609b9415 Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Mon, 30 Oct 2023 12:09:39 +0800 Subject: [PATCH 040/216] New recipe: tiny_transducer_ctc (#848) * initial commit * update readme * Update README.md * change bool to str2bool for arg parser * run validation only at the end of epoch * black format * black format --- .../ASR/tiny_transducer_ctc/README.md | 184 +++ .../ASR/tiny_transducer_ctc/asr_datamodule.py | 454 ++++++ .../ASR/tiny_transducer_ctc/beam_search.py | 1 + .../ASR/tiny_transducer_ctc/ctc_decode.py | 770 ++++++++++ .../ASR/tiny_transducer_ctc/decode.py | 717 ++++++++++ .../ASR/tiny_transducer_ctc/decoder.py | 1 + .../ASR/tiny_transducer_ctc/encoder.py | 379 +++++ .../tiny_transducer_ctc/encoder_interface.py | 1 + .../ASR/tiny_transducer_ctc/export.py | 316 +++++ .../ASR/tiny_transducer_ctc/jit_pretrained.py | 271 ++++ .../tiny_transducer_ctc/jit_pretrained_ctc.py | 426 ++++++ .../ASR/tiny_transducer_ctc/joiner.py | 1 + .../ASR/tiny_transducer_ctc/model.py | 1 + .../ASR/tiny_transducer_ctc/pretrained.py | 357 +++++ .../ASR/tiny_transducer_ctc/pretrained_ctc.py | 444 ++++++ .../ASR/tiny_transducer_ctc/scaling.py | 1 + .../ASR/tiny_transducer_ctc/train.py | 1251 +++++++++++++++++ 17 files changed, 5575 insertions(+) create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/README.md create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py create mode 
100644 egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/decode.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/decoder.py create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/encoder.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/export.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/joiner.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/model.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/scaling.py create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/train.py diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/README.md b/egs/librispeech/ASR/tiny_transducer_ctc/README.md new file mode 100644 index 000000000..78dbc12c9 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/README.md @@ -0,0 +1,184 @@ +## Introduction + +This recipe is intended for streaming ASR on very low cost devices, with model parameters in the range of 1-2M. It uses a small convolutional net as the encoder. It is trained with combined transducer and CTC losses, and supports both phone and BPE lexicons. For phone lexicon, you can do transducer decoding using a method with LG, but the results were bad. + +The encoder consists of 2 subsampling layers followed by a stack of Conv1d-batchnorm-activation-causal_squeeze_excite blocks, with optional skip connections. To reduce latency (at the cost of slightly higher WER), half of the blocks use causal convolution. + +A few remarks & observations: + +1. Phone lexicon works better than BPE for CTC decoding (with HLG) but worse for transducer decoding. + +2. SpecAugment is not helpful for very small models as they tend to underfit rather than overfit. For the large model, a less aggressive SpecAugment (see asr_datamodule.py) improved the result a little. + +3. Squeeze-and-excitation worked like a charm! It reduces WER quite a bit with marginal increase of parameters and MAC ops. To make it causal I changed the global average pooling layer to a moving average filter, so only historical context is used. + +## Pretrained models + +You can find pretrained models, training logs, decoding logs, and decoding results at: + + +## Results on full libri + +I tried 3 different sizes of the encoder. The parameters are around 1M, 2M and 4M, respectively. For CTC decoding, whole-lattice-rescoring frequently causes OOM error so the result is not shown. + +### Small encoder + +The small encoder uses 10 layers of 1D convolution block with 256 channels, without skip connections. The encoder, decoder and joiner dim is 256. Algorithmic latency is 280ms. Multiply-add ops for the encoder is 22.0Mops. It is more applicable for ASR products with limited vocabulary (like a fixed set of phrases or short sentences). 
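+
+To make the block structure described above a bit more concrete, here is a rough, self-contained PyTorch sketch of one causal Conv1d-batchnorm-activation-squeeze_excite block. It is not the code in `encoder.py`: the class names, reduction ratio and layer sizes are made-up choices for illustration, and a cumulative mean over past frames stands in for the moving-average pooling used in the actual encoder.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CausalSqueezeExcite1d(nn.Module):
+    """Squeeze-and-excitation over time, made causal by pooling only over
+    frames <= t (here via a running mean) instead of a global average
+    over the whole utterance."""
+
+    def __init__(self, channels: int, reduction: int = 8):
+        super().__init__()
+        hidden = max(channels // reduction, 4)
+        self.fc1 = nn.Linear(channels, hidden)
+        self.fc2 = nn.Linear(hidden, channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (N, C, T)
+        count = torch.arange(1, x.size(2) + 1, device=x.device).view(1, 1, -1)
+        pooled = torch.cumsum(x, dim=2) / count  # running mean, (N, C, T)
+        s = torch.relu(self.fc1(pooled.transpose(1, 2)))  # (N, T, hidden)
+        s = torch.sigmoid(self.fc2(s)).transpose(1, 2)  # back to (N, C, T)
+        return x * s  # per-channel, per-frame gating
+
+
+class CausalConvBlock1d(nn.Module):
+    """Conv1d -> BatchNorm -> ReLU -> causal squeeze-excite, with
+    left-only padding so the convolution never sees future frames."""
+
+    def __init__(self, channels: int, kernel_size: int = 3):
+        super().__init__()
+        self.left_pad = kernel_size - 1
+        self.conv = nn.Conv1d(channels, channels, kernel_size)
+        self.norm = nn.BatchNorm1d(channels)
+        self.se = CausalSqueezeExcite1d(channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (N, C, T); pad on the left only to keep the block causal.
+        y = F.pad(x, (self.left_pad, 0))
+        y = self.se(torch.relu(self.norm(self.conv(y))))
+        return x + y  # skip connection (optional in the recipe)
+
+
+if __name__ == "__main__":
+    x = torch.randn(2, 256, 100)  # (batch, channels, frames)
+    print(CausalConvBlock1d(256)(x).shape)  # torch.Size([2, 256, 100])
+```
+
+Stacking such blocks after the two subsampling layers, with roughly 10-18 layers and 256-400 channels, is how the small, middle and large configurations below differ.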
+ +#### CTC decoding with phone lexicon +Total parameters: 1073392 + +Parameters for CTC decoding: 865816 + +| | test-clean | test-other | comment | +|-----------------|------------|------------|----------------------| +| 1best | 9.68 | 24.9 | --epoch 30 --avg 2 | +| nbest-rescoring | 8.2 | 22.7 | --epoch 30 --avg 2 | + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_small_phone \ + --ctc-loss-scale 0.7 \ + --enable-spec-aug 0 \ + --lang-dir lang_phone \ + --encoder-dim 256 \ + --decoder-dim 256 \ + --joiner-dim 256 \ + --conv-layers 10 \ + --channels 256 \ + --skip-add 0 \ +``` + +#### Transducer decoding with BPE 500 lexicon +Total parameters: 1623264 + +Parameters for transducer decoding: 1237764 + +| | test-clean | test-other | comment | +|--------------------|------------|------------|----------------------| +| greedy_search | 14.47 | 32.03 | --epoch 30 --avg 1 | +| fast_beam_search | 13.38 | 29.61 | --epoch 30 --avg 1 | +|modified_beam_search| 13.02 | 29.32 | --epoch 30 --avg 1 | + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_small_bpe \ + --ctc-loss-scale 0.2 \ + --enable-spec-aug 0 \ + --lang-dir lang_bpe_500 \ + --encoder-dim 256 \ + --decoder-dim 256 \ + --joiner-dim 256 \ + --conv-layers 10 \ + --channels 256 \ + --skip-add 0 \ +``` + +### Middle encoder + +The middle encoder uses 18 layers of 1D convolution block with 300 channels, with skip connections. The encoder, decoder and joiner dim is 256. Algorithmic latency is 440ms. Multiply-add ops for the encoder is 50.1Mops. Note that the nbest-rescoring result is better than the tdnn_lstm_ctc recipe with whole-lattice-rescoring. + +#### CTC decoding with phone lexicon +Total parameters: 2186242 + +Parameters for CTC decoding: 1978666 + +| | test-clean | test-other | comment | +|-----------------|------------|------------|----------------------| +| 1best | 7.48 | 18.94 | --epoch 30 --avg 1 | +| nbest-rescoring | 6.31 | 16.89 | --epoch 30 --avg 1 | + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_middle_phone \ + --ctc-loss-scale 0.7 \ + --enable-spec-aug 0 \ + --lang-dir lang_phone \ + --encoder-dim 256 \ + --decoder-dim 256 \ + --joiner-dim 256 \ + --conv-layers 18 \ + --channels 300 \ + --skip-add 1 \ +``` + +#### Transducer decoding with BPE 500 lexicon +Total parameters: 2735794 + +Parameters for transducer decoding: 2350294 + +| | test-clean | test-other | comment | +|--------------------|------------|------------|----------------------| +| greedy_search | 10.26 | 25.13 | --epoch 30 --avg 2 | +| fast_beam_search | 9.69 | 23.58 | --epoch 30 --avg 2 | +|modified_beam_search| 9.43 | 23.53 | --epoch 30 --avg 2 | + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_middle_bpe \ + --ctc-loss-scale 0.2 \ + --enable-spec-aug 0 \ + --lang-dir lang_bpe_500 \ + --encoder-dim 256 \ + --decoder-dim 256 \ + --joiner-dim 256 \ + --conv-layers 18 \ + --channels 300 \ + --skip-add 1 \ +``` + +### Large encoder + +The large encoder uses 18 layers of 1D convolution block with 400 channels, with skip connections. The encoder, decoder and joiner dim is 400. 
Algorithmic latency is 440ms. Multiply-add ops for the encoder is 88.8Mops. It is interesting to see how much the gap is if we simply scale down more complicated models like Zipformer or emformer. + + +#### Transducer decoding with BPE 500 lexicon +Total parameters: 4821330 + +Parameters for transducer decoding: 4219830 + +| | test-clean | test-other | comment | +|--------------------|------------|------------|----------------------| +| greedy_search | 8.29 | 21.11 | --epoch 30 --avg 1 | +| fast_beam_search | 7.91 | 20.1 | --epoch 30 --avg 1 | +|modified_beam_search| 7.74 | 19.89 | --epoch 30 --avg 1 | + + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_large_bpe \ + --ctc-loss-scale 0.2 \ + --enable-spec-aug 1 \ + --lang-dir lang_bpe_500 \ + --encoder-dim 400 \ + --decoder-dim 400 \ + --joiner-dim 400 \ + --conv-layers 18 \ + --channels 400 \ + --skip-add 1 \ +``` diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py new file mode 100644 index 000000000..8facb6dba --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py @@ -0,0 +1,454 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=False, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=0, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_feature_masks=2, + features_mask_size=5, + num_frame_masks=10, + frames_mask_size=5, + p=0.5, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get 
the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py b/egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py new file mode 100644 index 000000000..402aeac0c --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py @@ -0,0 +1,770 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import pprint +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="tiny_transducer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_phone", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="1best", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. 
+ Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.7, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "context_size": 2, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. 
+ """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, _ = model.encoder(feature, feature_lens) + nnet_output = model.ctc_output(encoder_out) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="trunc", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="trunc", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + # lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + # lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + # lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + lm_scale_list = [0.6, 0.7, 0.8, 0.9] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-hlg-scale-{params.hlg_scale}" + + if params.use_averaged_model: + params.suffix += "-uam" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(pprint.pformat(params, indent=2)) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == 
"ctc-decoding": + assert "lang_bpe" in str( + params.lang_dir + ), "ctc-decoding only supports BPE lexicons." + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + enc_param = sum([p.numel() for p in model.encoder.parameters()]) + ctc_param = sum([p.numel() for p in model.ctc_output.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + logging.info(f"Parameters for CTC decoding: {enc_param + ctc_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py new file mode 100644 index 000000000..6c2bf9ea1 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py @@ -0,0 +1,717 @@ +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import pprint +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="tiny_transducer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="fast_beam_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.1, + help=""" + Used only when --decoding_method is fast_beam_search_LG or + fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. 
+ model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # 
fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
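# A minimal sketch (not code from this recipe; defaults mirror the argparse
# values further up) of how the result key assembled in decode_one_batch()
# above encodes the fast_beam_search settings, so each configuration is
# written to its own recogs/errs file.
def make_result_key(method: str, beam: float, max_contexts: int, max_states: int,
                    num_paths: int = 100, nbest_scale: float = 0.5,
                    ngram_lm_scale: float = 0.1) -> str:
    key = f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}"
    if "nbest" in method:
        key += f"_num_paths_{num_paths}_nbest_scale_{nbest_scale}"
    if "LG" in method:
        key += f"_ngram_lm_scale_{ngram_lm_scale}"
    return key

assert make_result_key("fast_beam_search", 20.0, 8, 64) == (
    "beam_20.0_max_contexts_8_max_states_64"
)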
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + if "lang_phone" in str(params.lang_dir): + assert params.decoding_method in ( + "fast_beam_search_LG", + "fast_beam_search_nbest_LG", + ), "For phone lexicon, use a decoding method with LG." 
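# A small illustration (hypothetical WER numbers, not measured results) of what
# save_results() above does with the per-setting WERs: sort ascending so the
# first entry is the best configuration for that test set.
test_set_wers = {
    "beam_20.0_max_contexts_8_max_states_64": 9.8,   # hypothetical
    "beam_size_4": 10.2,                             # hypothetical
}
ranked = sorted(test_set_wers.items(), key=lambda x: x[1])
best_key, best_wer = ranked[0]
print(f"{best_key}\t{best_wer}\tbest for test-clean")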
+ + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-uam" + + setup_logger(f"{params.res_dir}/log-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + if "lang_bpe" in str(params.lang_dir): + sp = spm.SentencePieceProcessor() + sp.load(str(params.lang_dir / "bpe.model")) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + else: + params.blank_id = lexicon.token_table.get("") + params.unk_id = lexicon.token_table.get("SPN") + params.vocab_size = max(lexicon.tokens) + 1 + sp = None + + logging.info(pprint.pformat(params, indent=2)) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + 
filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + enc_param = sum([p.numel() for p in model.encoder.parameters()]) + dec_param = sum([p.numel() for p in model.decoder.parameters()]) + join_param = sum([p.numel() for p in model.joiner.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + logging.info( + f"Parameters for transducer decoding: {enc_param + dec_param + join_param}" + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decoder.py b/egs/librispeech/ASR/tiny_transducer_ctc/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py new file mode 100644 index 000000000..4c7fca4fc --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python3 +# Copyright (c) 2022 Spacetouch Inc. (author: Tiance Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn.functional as F +from encoder_interface import EncoderInterface +from scaling import ActivationBalancer, DoubleSwish +from torch import Tensor, nn + + +class Conv1dNet(EncoderInterface): + """ + 1D Convolution network with causal squeeze and excitation + module and optional skip connections. + + Latency: 80ms + (conv_layers+1) // 2 * 40ms, assuming 10ms stride. + + Args: + output_dim (int): Number of output channels of the last layer. + input_dim (int): Number of input features + conv_layers (int): Number of convolution layers, + excluding the subsampling layers. + channels (int): Number of output channels for each layer, + except the last layer. + subsampling_factor (int): The subsampling factor for the model. + skip_add (bool): Whether to use skip connection for each convolution layer. + dscnn (bool): Whether to use depthwise-separated convolution. + activation (str): Activation function type. + """ + + def __init__( + self, + output_dim: int, + input_dim: int = 80, + conv_layers: int = 10, + channels: int = 256, + subsampling_factor: int = 4, + skip_add: bool = False, + dscnn: bool = True, + activation: str = "relu", + ) -> None: + super().__init__() + assert subsampling_factor == 4, "Only support subsampling = 4" + + self.conv_layers = conv_layers + self.skip_add = skip_add + # 80ms latency for subsample_layer + self.subsample_layer = nn.Sequential( + conv1d_bn_block( + input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn + ), + conv1d_bn_block( + channels, channels, 5, stride=2, activation=activation, dscnn=dscnn + ), + ) + + self.conv_blocks = nn.ModuleList() + cin = [channels] * conv_layers + cout = [channels] * (conv_layers - 1) + [output_dim] + + # Use causal and standard convolution alternatively + for ly in range(conv_layers): + self.conv_blocks.append( + nn.Sequential( + conv1d_bn_block( + cin[ly], + cout[ly], + 3, + activation=activation, + dscnn=dscnn, + causal=ly % 2, + ), + CausalSqueezeExcite1d(cout[ly], 16, 30), + ) + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = self.subsample_layer(x) + for idx, layer in enumerate(self.conv_blocks): + if self.skip_add and 0 < idx < self.conv_layers - 1: + x = layer(x) + x + else: + x = layer(x) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + lengths = x_lens >> 2 + return x, lengths + + +def get_activation( + name: str, + channels: int, + channel_dim: int = -1, + min_val: int = 0, + max_val: int = 1, +) -> nn.Module: + """ + Get activation function from name in string. + + Args: + name: activation function name + channels: only used for PReLU, should be equal to x.shape[1]. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + e.g. 
for NCHW tensor, channel_dim = 1 + min_val: minimum value of hardtanh + max_val: maximum value of hardtanh + + Returns: + The activation function module + + """ + act_layer = nn.Identity() + name = name.lower() + if name == "prelu": + act_layer = nn.PReLU(channels) + elif name == "relu": + act_layer = nn.ReLU() + elif name == "relu6": + act_layer = nn.ReLU6() + elif name == "hardtanh": + act_layer = nn.Hardtanh(min_val, max_val) + elif name in ["swish", "silu"]: + act_layer = nn.SiLU() + elif name == "elu": + act_layer = nn.ELU() + elif name == "doubleswish": + act_layer = nn.Sequential( + ActivationBalancer(num_channels=channels, channel_dim=channel_dim), + DoubleSwish(), + ) + elif name == "": + act_layer = nn.Identity() + else: + raise Exception(f"Unknown activation function: {name}") + + return act_layer + + +class CausalSqueezeExcite1d(nn.Module): + """ + Causal squeeze and excitation module with input and output shape + (batch, channels, time). The global average pooling in the original + SE module is replaced by a causal filter, so + the layer does not introduce any algorithmic latency. + + Args: + channels (int): Number of channels + reduction (int): channel reduction rate + context_window (int): Context window size for the moving average operation. + For EMA, the smoothing factor is 1 / context_window. + """ + + def __init__( + self, + channels: int, + reduction: int = 16, + context_window: int = 10, + ) -> None: + super(CausalSqueezeExcite1d, self).__init__() + + assert channels >= reduction + + self.context_window = context_window + c_squeeze = channels // reduction + self.linear1 = nn.Linear(channels, c_squeeze, bias=True) + self.act1 = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(c_squeeze, channels, bias=True) + self.act2 = nn.Sigmoid() + + # EMA worked better than MA empirically + # self.avg_filter = self.moving_avg + self.avg_filter = self.exponential_moving_avg + self.ema_matrix = torch.tensor([0]) + self.ema_matrix_size = 0 + + def _precompute_ema_matrix(self, N: int, device: torch.device): + a = 1.0 / self.context_window # smoothing factor + w = [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)] + w = torch.tensor(w).to(device).tril() + w[:, 0] *= self.context_window + self.ema_matrix = w.T + self.ema_matrix_size = N + + def exponential_moving_avg(self, x: Tensor) -> Tensor: + """ + Exponential moving average filter, which is calculated as: + y[t] = (1-a) * y[t-1] + a * x[t] + where a = 1 / self.context_window is the smoothing factor. + + For training, the iterative version is too slow. A better way is + to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication. + The weight matrix can be precomputed if the smoothing factor is fixed. + """ + if self.training: + # use matrix version to speed up training + N = x.shape[-1] + if N > self.ema_matrix_size: + self._precompute_ema_matrix(N, x.device) + y = torch.matmul(x, self.ema_matrix[:N, :N]) + else: + # use iterative version to save memory + a = 1.0 / self.context_window + y = torch.empty_like(x) + y[:, :, 0] = x[:, :, 0] + for t in range(1, y.shape[-1]): + y[:, :, t] = (1 - a) * y[:, :, t - 1] + a * x[:, :, t] + return y + + def moving_avg(self, x: Tensor) -> Tensor: + """ + Simple moving average with context_window as window size. 
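# A self-contained check (plain PyTorch, illustrative sizes) that the
# precomputed lower-triangular matrix built in _precompute_ema_matrix() above
# reproduces the recursion y[t] = (1 - a) * y[t-1] + a * x[t] with y[0] = x[0].
import torch

context_window = 10
a = 1.0 / context_window
N = 16
w = [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)]
w = torch.tensor(w).tril()
w[:, 0] *= context_window          # makes y[0] equal to x[0] exactly
ema_matrix = w.T

x = torch.randn(2, 4, N)           # (batch, channels, time)
y_matrix = torch.matmul(x, ema_matrix)

y_iter = torch.empty_like(x)
y_iter[:, :, 0] = x[:, :, 0]
for t in range(1, N):
    y_iter[:, :, t] = (1 - a) * y_iter[:, :, t - 1] + a * x[:, :, t]

assert torch.allclose(y_matrix, y_iter, atol=1e-5)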
+ """ + y = torch.empty_like(x) + k = min(x.shape[2], self.context_window) + w = [[1 / n] * n + [0] * (k - n - 1) for n in range(1, k)] + w = torch.tensor(w, device=x.device) + y[:, :, : k - 1] = torch.matmul(x[:, :, : k - 1], w.T) + y[:, :, k - 1 :] = F.avg_pool1d(x, k, 1) + return y + + def forward(self, x: Tensor) -> Tensor: + + assert len(x.shape) == 3, "Input is not a 3D tensor!" + y = self.exponential_moving_avg(x) + y = y.permute(0, 2, 1) # make channel last for squeeze op + y = self.act1(self.linear1(y)) + y = self.act2(self.linear2(y)) + y = y.permute(0, 2, 1) # back to original shape + y = x * y + return y + + +def conv1d_bn_block( + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + activation: str = "relu", + dscnn: bool = False, + causal: bool = False, +) -> nn.Sequential: + """ + Conv1d - batchnorm - activation block. + If kernel size is even, output length = input length + 1. + Otherwise, output and input lengths are equal. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + kernel_size (int): kernel size + stride (int): convolution stride + dilation (int): convolution dilation rate + dscnn (bool): Use depthwise separated convolution. + causal (bool): Use causal convolution + activation (str): Activation function type. + + """ + if dscnn: + return nn.Sequential( + CausalConv1d( + in_channels, + in_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=in_channels, + bias=False, + ) + if causal + else nn.Conv1d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=(kernel_size // 2) * dilation, + dilation=dilation, + groups=in_channels, + bias=False, + ), + nn.BatchNorm1d(in_channels), + get_activation(activation, in_channels), + nn.Conv1d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm1d(out_channels), + get_activation(activation, out_channels), + ) + else: + return nn.Sequential( + CausalConv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + bias=False, + ) + if causal + else nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=(kernel_size // 2) * dilation, + dilation=dilation, + bias=False, + ), + nn.BatchNorm1d(out_channels), + get_activation(activation, out_channels), + ) + + +class CausalConv1d(nn.Module): + """ + Causal convolution with padding automatically chosen to match input/output length. 
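# A minimal sketch (toy shapes, plain torch) of the causal-padding trick that
# the CausalConv1d class defined here relies on: pad by dilation*(kernel_size-1)
# on both sides via nn.Conv1d's `padding` argument, then drop the trailing
# frames so the output keeps the input length and frame t never sees frames > t.
import torch
import torch.nn as nn

kernel_size, dilation, stride = 3, 1, 1
padding = dilation * (kernel_size - 1)
conv = nn.Conv1d(1, 1, kernel_size, stride, padding, dilation, bias=False)

x = torch.randn(1, 1, 20)
y = conv(x)[:, :, : -padding // stride]
assert y.shape[-1] == x.shape[-1]

# Causality check: zeroing the future (frames >= 6) must not change y[..., :6].
x2 = x.clone()
x2[:, :, 6:] = 0.0
y2 = conv(x2)[:, :, : -padding // stride]
assert torch.allclose(y[:, :, :6], y2[:, :, :6])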
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + ) -> None: + super(CausalConv1d, self).__init__() + assert kernel_size > 2 + + self.padding = dilation * (kernel_size - 1) + self.stride = stride + + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + self.padding, + dilation, + groups, + bias=bias, + ) + + def forward(self, x: Tensor) -> Tensor: + return self.conv(x)[:, :, : -self.padding // self.stride] diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py b/egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/export.py b/egs/librispeech/ASR/tiny_transducer_ctc/export.py new file mode 100755 index 000000000..4117f7244 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/export.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 30 \ + --avg 2 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 30 \ + --avg 2 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `tiny_transducer_ctc/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./tiny_transducer_ctc/decode.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --epoch 9999 \ + --use-averaged-model 0 + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --lang-dir data/lang_bpe_500 \ + +Check ./pretrained.py for its usage. 
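# Illustrative helper (not part of this script) that reproduces which files the
# --epoch/--avg flags above select when --use-averaged-model=False and --iter=0:
# the `avg` consecutive epoch checkpoints ending at `epoch`, as in main() below.
def checkpoints_to_average(exp_dir: str, epoch: int, avg: int) -> list:
    start = epoch - avg + 1
    return [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1) if i >= 1]

assert checkpoints_to_average("tiny_transducer_ctc/exp", epoch=30, avg=2) == [
    "tiny_transducer_ctc/exp/epoch-29.pt",
    "tiny_transducer_ctc/exp/epoch-30.pt",
]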
+ +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import UniqLexicon +from icefall.utils import str2bool +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="tiny_transducer_ctc/exp_4m_bpe500_halfdelay_specaug", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if "lang_bpe" in str(params.lang_dir): + sp = spm.SentencePieceProcessor() + sp.load(params.lang_dir + "/bpe.model") + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + else: + assert "lang_phone" in str(params.lang_dir) + phone_lexicon = UniqLexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(phone_lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
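# A tiny standalone illustration (generic nn.Module, not this recipe's model) of
# the torch.jit.ignore trick applied on the next line: the ignored forward() is
# left as plain Python and skipped by the compiler, so its unscriptable
# arguments no longer block torch.jit.script(); exported methods still compile.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x, ragged_thing):  # pretend this argument is unscriptable
        raise RuntimeError("not used after export")

    @torch.jit.export
    def run(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

toy = Toy()
toy.__class__.forward = torch.jit.ignore(toy.__class__.forward)
scripted = torch.jit.script(toy)
assert torch.equal(scripted.run(torch.ones(3)), torch.full((3,), 2.0))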
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py new file mode 100755 index 000000000..3888d3544 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./tiny_transducer_ctc/jit_pretrained.py \ + --nn-model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
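# Background sketch (plain PyTorch, made-up shapes) for the batched greedy
# search defined further below: pack_padded_sequence with enforce_sorted=False
# reorders utterances by length and exposes `batch_sizes`, i.e. how many
# utterances are still active at each frame, which the frame loop consumes.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

encoder_out = torch.randn(3, 5, 4)                # (N=3 utts, T=5 frames, C=4)
lens = torch.tensor([5, 2, 4])
packed = pack_padded_sequence(
    encoder_out, lens, batch_first=True, enforce_sorted=False
)
print(packed.batch_sizes.tolist())                # [3, 3, 2, 2, 1]
# frames 0-1: all 3 utterances active; frames 2-3: two; frame 4: only the longest.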
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = 
fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py new file mode 100755 index 000000000..6f2cbaabd --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
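# Why pad_sequence uses padding_value=math.log(1e-10) in the decoding scripts
# here (a toy example with dummy features): the features are log-Mel energies,
# so padding with the log of a tiny energy (about -23.0) behaves like silence,
# whereas padding with 0.0 would correspond to an energy of 1.0 in log domain.
import math
import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.zeros(3, 80), torch.zeros(5, 80)]      # two dummy utterances
padded = pad_sequence(feats, batch_first=True, padding_value=math.log(1e-10))
print(padded.shape)                        # torch.Size([2, 5, 80])
print(round(padded[0, 4, 0].item(), 2))    # -23.03, the "silence" fill value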
+You can use the following command to get the exported models: + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(3) nbest-rescoring +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. 
+ It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/joiner.py b/egs/librispeech/ASR/tiny_transducer_ctc/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/model.py b/egs/librispeech/ASR/tiny_transducer_ctc/model.py new file mode 120000 index 000000000..545af927f --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_ctc/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py new file mode 100755 index 000000000..981039b8f --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
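# A sketch (hypothetical four-entry words.txt, not real data) of the last step
# of HLG decoding above: get_texts() returns word IDs, and the symbol table
# loaded with k2.SymbolTable.from_file() maps them back to strings.
word_table = {0: "<eps>", 1: "HELLO", 2: "WORLD", 3: "<UNK>"}
hyp_ids = [[1, 2], [3]]
hyps = [[word_table[i] for i in ids] for ids in hyp_ids]
print([" ".join(h) for h in hyps])         # ['HELLO WORLD', '<UNK>']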
+You can generate the checkpoint with the following command: + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./tiny_transducer_ctc/pretrained.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./tiny_transducer_ctc/pretrained.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./tiny_transducer_ctc/pretrained.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./tiny_transducer_ctc/pretrained.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./tiny_transducer_ctc/exp/epoch-xx.pt`. + +Note: ./tiny_transducer_ctc/exp/pretrained.pt is generated by +./tiny_transducer_ctc/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.lang_dir + "/bpe.model") + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + 
encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py new file mode 100755 index 000000000..a06d6d684 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py @@ -0,0 +1,444 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) ctc-decoding +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. 
+ (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + params.blank_id = 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = 
k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/scaling.py b/egs/librispeech/ASR/tiny_transducer_ctc/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py new file mode 100644 index 000000000..307ad72aa --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -0,0 +1,1251 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: + +cd egs/librispeech/ASR/ +./prepare.sh + +Run below if you want to use the phone lexicon instead of BPE: +python local/generate_unique_lexicon.py --lang-dir data/lang_phone + +""" +import argparse +import copy +import logging +import pprint +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from encoder import Conv1dNet +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW +from torch.optim.lr_scheduler import StepLR +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics, is_module_available +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import UniqLexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + + parser.add_argument( + "--encoder-dim", + type=int, + default=256, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=256, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=256, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--conv-layers", + type=int, + default=10, + help="""Number of convolution layers for the encoder. + """, + ) + + parser.add_argument( + "--channels", + type=int, + default=256, + help="""Number of channels for the encoder. + """, + ) + + parser.add_argument( + "--skip-add", + type=str2bool, + default=False, + help="""Use skip connection in the encoder. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="tiny_transducer_ctc/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="""Weight for CTC loss, between 0 and 1. + When set to 0, only transducer loss is used. + When set to 1, only CTC loss is used.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. 
+ """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=5, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. 
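+
+    - beam_size: The output beam passed to k2.ctc_loss when computing
+          the CTC loss.
+
+    - use_dscnn, activation: Encoder options forwarded to Conv1dNet
+          (see encoder.py).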
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 500, + "reset_interval": 200, + "warm_step": 5000, + "beam_size": 10, + "use_double_scores": True, + "env_info": get_env_info(), + "feature_dim": 80, + "subsampling_factor": 4, + "use_dscnn": True, + "activation": "doubleswish", + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + + encoder = Conv1dNet( + output_dim=params.encoder_dim, + input_dim=params.feature_dim, + conv_layers=params.conv_layers, + channels=params.channels, + subsampling_factor=params.subsampling_factor, + skip_add=params.skip_add, + dscnn=params.use_dscnn, + activation=params.activation, + ) + + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> Transducer: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + if is_module_available("thop"): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + # Assuming 10ms stride, 1000 frames is about 10 seconds. + x = torch.zeros((1, 1000, params.feature_dim)).to(device) + x_lens = torch.Tensor([1000]).int().to(device) + from thop import clever_format, profile + + m = copy.deepcopy(encoder) + m = m.to(device) + ops, _ = clever_format(profile(m, (x, x_lens), verbose=False)) + logging.info(f"Encoder MAC ops for 10 seconds of audio is {ops}") + else: + logging.info("You can install thop to calculate the number of ops.") + logging.info("Command: pip install thop") + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
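+    Returns:
+      Return a tuple of two elements: the total loss and a MetricsTracker
+      object holding statistics for this batch (frames and the individual
+      simple/pruned/CTC loss values).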
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + if sp is not None: + token_ids = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(token_ids).to(device) + else: + y = phone_lexicon.texts_to_token_ids(texts).to(device) + token_ids = y.tolist() + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_output = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + # Compute ctc loss + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + assert 0 <= params.ctc_loss_scale <= 1, "ctc_loss_scale must be between 0 and 1" + loss = params.ctc_loss_scale * ctc_loss + (1 - params.ctc_loss_scale) * loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
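+    # For reference, with the default warm_step=5000, simple_loss_scale=0.5
+    # and ctc_loss_scale=0.2, the `loss` recorded below equals
+    # 0.2 * ctc_loss + 0.8 * (simple_scale * simple + pruned_scale * pruned),
+    # where simple_scale decays linearly from 1.0 to 0.5 over the first 5000
+    # batches (e.g. 0.75 at batch 2500) and pruned_scale grows from 0.1 to 1.0.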
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
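+            # scaler.scale(loss) multiplies the loss by the current grad scale
+            # (a no-op when use_fp16 is False) so that fp16 gradients do not
+            # underflow; scaler.step() un-scales the gradients again before the
+            # optimizer update. tot_loss above is a running average that
+            # discounts older batches by 1 - 1/reset_interval (0.995 with the
+            # default of 200) on every step.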
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + # scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch( + batch, params=params, sp=sp, phone_lexicon=phone_lexicon + ) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + if "lang_bpe" in str(params.lang_dir): + sp = spm.SentencePieceProcessor() + sp.load(params.lang_dir + "/bpe.model") + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + phone_lexicon = None + else: + assert "lang_phone" in str(params.lang_dir) + phone_lexicon = UniqLexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(phone_lexicon.tokens) + 1 + sp = None + + logging.info(pprint.pformat(params, indent=2)) + + logging.info("About to create model") + model = get_transducer_model(params) + + if rank == 0: + num_param = sum([p.numel() for p in model.parameters()]) + enc_param = sum([p.numel() for p in model.encoder.parameters()]) + dec_param = sum([p.numel() for p in model.decoder.parameters()]) + join_param = sum([p.numel() for p in model.joiner.parameters()]) + ctc_param = sum([p.numel() for p in model.ctc_output.parameters()]) + + logging.info(f"Number of model parameters: {num_param}") + logging.info(f"Number of encoder parameters: {enc_param}") + logging.info(f"Number of decoder parameters: {dec_param}") + logging.info(f"Number of joiner parameters: {join_param}") + logging.info(f"Number of ctc parameters: {ctc_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = AdamW( + model.parameters(), + lr=params.initial_lr, + weight_decay=5e-4, + ) + + scheduler = StepLR(optimizer, step_size=2, gamma=0.8) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # phone_lexicon=phone_lexicon, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + # scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + phone_lexicon=phone_lexicon, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + if sp is not None: + y = sp.encode(supervisions["text"], out_type=int) + else: + y = phone_lexicon.texts_to_token_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch( + batch, params=params, sp=sp, phone_lexicon=phone_lexicon + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 23913f6afdea59caf703e3ac715852810cd246ad Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 31 Oct 2023 10:28:20 +0800 Subject: [PATCH 041/216] Minor refinements for some stale but recently merged PRs (#1354) * incorporate https://github.com/k2-fsa/icefall/pull/1269 * incorporate https://github.com/k2-fsa/icefall/pull/1301 * black formatted * incorporate https://github.com/k2-fsa/icefall/pull/1162 * black formatted --- egs/aishell/ASR/zipformer/train.py | 2 +- .../ASR/zipformer/asr_datamodule.py | 2 +- egs/gigaspeech/ASR/zipformer/train.py | 2 +- .../ASR/zipformer_prompt_asr/optim.py | 12 +++---- .../zipformer_prompt_asr/train_baseline.py | 2 +- .../train_bert_encoder.py | 2 +- .../ASR/tiny_transducer_ctc/asr_datamodule.py | 8 ++--- .../ASR/tiny_transducer_ctc/ctc_decode.py | 3 +- .../ASR/tiny_transducer_ctc/encoder.py | 1 - .../ASR/tiny_transducer_ctc/export.py | 31 ++++++----------- .../ASR/tiny_transducer_ctc/train.py | 4 +-- egs/librispeech/ASR/zipformer_ctc/export.py | 21 +++++------- egs/librispeech/ASR/zipformer_ctc/train.py | 2 +- icefall/diagnostics.py | 33 ++++++++++++------- 14 files changed, 57 insertions(+), 68 deletions(-) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index 7e7b02829..d381649e4 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index c4472ed23..6adfdbfbb 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -233,7 +233,7 @@ class GigaSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git 
a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index d8ff4fecc..d93cc221c 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -1164,7 +1164,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py index a767761eb..159e363c7 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. 
batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index 32302602c..c8b20d021 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -1194,7 +1194,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py index e253d1118..9822b99c1 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -1565,7 +1565,7 @@ def run(rank, world_size, args): if params.print_diagnostics: args.max_duration = 100 opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py index 8facb6dba..3acd22ae4 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -225,7 +225,7 @@ class LibriSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -307,8 +307,8 @@ class LibriSpeechAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py index 402aeac0c..cda03b56e 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py @@ -22,10 +22,11 @@ import argparse import logging import math +import pprint from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple -import pprint + import k2 import sentencepiece as spm import torch diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py 
b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py index 4c7fca4fc..afdd00293 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py @@ -253,7 +253,6 @@ class CausalSqueezeExcite1d(nn.Module): return y def forward(self, x: Tensor) -> Tensor: - assert len(x.shape) == 3, "Input is not a 3D tensor!" y = self.exponential_moving_avg(x) y = y.permute(0, 2, 1) # make channel last for squeeze op diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/export.py b/egs/librispeech/ASR/tiny_transducer_ctc/export.py index 4117f7244..334dd011e 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/export.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/export.py @@ -76,17 +76,17 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch +from train import add_model_arguments, get_params, get_transducer_model + from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) -from icefall.lexicon import UniqLexicon -from icefall.utils import str2bool -from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -143,13 +143,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -189,17 +186,9 @@ def main(): logging.info(f"device: {device}") - if "lang_bpe" in str(params.lang_dir): - sp = spm.SentencePieceProcessor() - sp.load(params.lang_dir + "/bpe.model") - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - else: - assert "lang_phone" in str(params.lang_dir) - phone_lexicon = UniqLexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(phone_lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 307ad72aa..8920764cd 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -89,7 +89,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( "--encoder-dim", type=int, @@ -405,7 +404,6 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conv1dNet( output_dim=params.encoder_dim, input_dim=params.feature_dim, @@ -1043,7 +1041,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/zipformer_ctc/export.py b/egs/librispeech/ASR/zipformer_ctc/export.py index 0ff50f128..4c46aea2c 100755 --- a/egs/librispeech/ASR/zipformer_ctc/export.py +++ b/egs/librispeech/ASR/zipformer_ctc/export.py @@ -23,6 +23,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from 
train import add_model_arguments, get_ctc_model, get_params @@ -33,8 +34,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -90,11 +90,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -113,17 +112,15 @@ def get_parser(): def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - logging.info(params) + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - params.vocab_size = num_classes + logging.info(params) device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index f40344357..60990456d 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -947,7 +947,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index ebf61784e..65b6f67b0 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -244,16 +244,14 @@ class TensorDiagnostic(object): if stats_type == "eigs": try: - if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'): + if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): eigs, _ = torch.linalg.eigh(stats) else: eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print( - "Error getting eigenvalues, trying another method." 
- ) - if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'): + print("Error getting eigenvalues, trying another method.") + if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"): eigs, _ = torch.linalg.eig(stats) eigs = eigs.abs() else: @@ -579,10 +577,15 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, - class_name=get_class_name(_module)) - + if isinstance(o, Tensor) and o.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] @@ -596,9 +599,15 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=get_class_name(_module)) + if isinstance(o, Tensor) and o.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) + module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) From 9e5a5d7839aa3052e46dcf25b239a37449f8cd5e Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 2 Nov 2023 16:10:08 +0800 Subject: [PATCH 042/216] Incorporate some latest changes to `optim.py` (#1359) * init commit * black formatted * isort formatted --- egs/librispeech/ASR/zipformer/optim.py | 171 +++++++++++++++++-------- 1 file changed, 121 insertions(+), 50 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8ee2b0eb4..a663db708 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch from lhotse.utils import fix_random_seed -from torch import Tensor +from torch import Tensor, nn from torch.optim import Optimizer @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for (stacked_params, _state, _names), batch in zip(tuples, batches): + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,6 +181,7 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): + defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -326,7 +327,9 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): + with self.batched_params(group["params"], group_params_names) as batches: + # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -423,16 +426,19 @@ class ScaledAdam(BatchedOptimizer): # parameters' state won't have been initialized yet. 
return 1.0 clipping_update_period = group["clipping_update_period"] + scalar_lr_scale = group["scalar_lr_scale"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for p, state, param_names in tuples: + for (p, state, param_names) in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( "ScaledAdam optimizer does not support sparse gradients" ) if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + tot_sumsq += (grad**2).sum() * ( + scalar_lr_scale**2 + ) # sum() to change shape [1] to [] else: tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() @@ -443,64 +449,72 @@ class ScaledAdam(BatchedOptimizer): ) first_state["model_norms"][step % clipping_update_period] = tot_norm - if step % clipping_update_period == 0: + irregular_estimate_steps = [ + i for i in [10, 20, 40] if i < clipping_update_period + ] + if step % clipping_update_period == 0 or step in irregular_estimate_steps: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + if step in irregular_estimate_steps: + sorted_norms = sorted_norms[-step:] + num_norms = sorted_norms.numel() quartiles = [] for n in range(0, 5): - index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n - ) + index = min(num_norms - 1, (num_norms // 4) * n) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median + if step in irregular_estimate_steps: + # use larger thresholds on first few steps of estimating threshold, + # as norm may be changing rapidly. + threshold = threshold * 2.0 first_state["model_norm_threshold"] = threshold percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period + first_state["num_clipped"] * 100.0 / num_norms if "num_clipped" in first_state else 0.0 ) first_state["num_clipped"] = 0 quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( + logging.warn( f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" ) - if step < clipping_update_period: - return 1.0 # We have not yet estimated a norm to clip to. - else: - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) - return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(tuples, tot_sumsq) - if ans != ans: # e.g. ans is nan - ans = 0.0 - if ans == 0.0: - for p, state, param_names in tuples: - p.grad.zero_() # get rid of infinity() + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + return 1.0 # threshold has not yet been set. - return ans + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans != ans: # e.g. 
ans is nan + ans = 0.0 + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter( + tuples, tot_sumsq, group["scalar_lr_scale"] + ) + + if ans == 0.0: + for (p, state, param_names) in tuples: + p.grad.zero_() # get rid of infinity() + + return ans def _show_gradient_dominating_parameter( - self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + self, + tuples: List[Tuple[Tensor, dict, List[str]]], + tot_sumsq: Tensor, + scalar_lr_scale: float, ): """ Show information of parameter which dominates tot_sumsq. @@ -516,29 +530,30 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for p, state, batch_param_names in tuples: + for (p, state, batch_param_names) in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad**2 # Dummy values used by following `zip` statement. - batch_rms_orig = torch.ones(p.shape[0]) + batch_rms_orig = torch.full( + p.shape, scalar_lr_scale, device=batch_grad.device + ) else: batch_rms_orig = state["param_rms"] - batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2 + if batch_grad.ndim > 1: + # need to guard it with if-statement because sum() sums over + # all dims if dim == (). + batch_sumsq_orig = batch_sumsq_orig.sum( dim=list(range(1, batch_grad.ndim)) ) - for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): + proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - assert torch.isclose( - sum([value[0] for value in all_sumsq_orig.values()]).cpu(), - torch.tensor(1.0), - ) sorted_by_proportion = { k: v for k, v in sorted( @@ -552,7 +567,7 @@ class ScaledAdam(BatchedOptimizer): dominant_rms, dominant_grad, ) = sorted_by_proportion[dominant_param_name] - logging.info( + logging.warn( f"Parameter dominating tot_sumsq {dominant_param_name}" f" with proportion {dominant_proportion:.2f}," f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" @@ -826,7 +841,7 @@ class LRScheduler(object): def print_lr(self, is_verbose, group, lr): """Display the current learning rate.""" if is_verbose: - logging.info( + logging.warn( f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" f" of group {group} to {lr:.4e}." ) @@ -841,8 +856,14 @@ class Eden(LRScheduler): where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches and then stays constant at 1. + If you don't have the concept of epochs, or one epoch takes a very long time, + you can replace the notion of 'epoch' with some measure of the amount of data + processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to + some measure representing "quite a lot of data": say, one fifth or one third + of an entire training run, but it doesn't matter much. You could also use + Eden2 which has only the notion of batches. - E.g. 
suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam Args: optimizer: the optimizer to change the learning rates on @@ -888,6 +909,56 @@ class Eden(LRScheduler): return [x * factor * warmup_factor for x in self.base_lrs] +class Eden2(LRScheduler): + """ + Eden2 scheduler, simpler than Eden because it does not use the notion of epoch, + only batches. + + The basic formula (before warmup) is: + lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup + + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + verbose: bool = False, + ): + super().__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.warmup_batches = warmup_batches + + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.5 + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + # else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + def _test_eden(): m = torch.nn.Linear(100, 100) optim = ScaledAdam(m.parameters(), lr=0.03) From c3bbb32f9ec6402f20582020eed64b159c55796f Mon Sep 17 00:00:00 2001 From: wnywbyt <45236066+wnywbyt@users.noreply.github.com> Date: Thu, 2 Nov 2023 20:45:30 +0800 Subject: [PATCH 043/216] Update the parameter 'vocab-size' (#1364) Co-authored-by: wdq --- egs/wenetspeech/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index f7eb9f0d0..b0525de60 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -362,6 +362,6 @@ if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then --exp-dir rnnlm_char/exp \ --lm-data data/lm_char/sorted_lm_data.pt \ --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ - --vocab-size 4336 \ + --vocab-size 5537 \ --master-port 12340 fi \ No newline at end of file From 231bbcd2b638826a94cf019fa31ae8683d3552ee Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 3 Nov 2023 12:06:29 +0800 Subject: [PATCH 044/216] Update optim.py (#1366) --- egs/librispeech/ASR/zipformer/optim.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index a663db708..714d8db9a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -466,6 +466,8 @@ class ScaledAdam(BatchedOptimizer): quartiles.append(sorted_norms[index].item()) median = quartiles[2] + if median - median != 0: + raise RuntimeError("Too many grads were not finite") threshold = clipping_scale * median if step in irregular_estimate_steps: # use larger thresholds on first few steps of estimating threshold, From 1b2e99d374cbbc527bf8c9239d616497249ccb1d Mon Sep 17 00:00:00 2001 From: lishaojie 
<95117087+manbaaaa@users.noreply.github.com> Date: Thu, 9 Nov 2023 22:07:28 +0800 Subject: [PATCH 045/216] add the pruned_transducer_stateless7_streaming recipe for commonvoice (#1018) * add the pruned_transducer_stateless7_streaming recipe for commonvoice * fix the symlinks * Update RESULTS.md --- egs/commonvoice/ASR/RESULTS.md | 25 + egs/commonvoice/ASR/local/compile_hlg.py | 1 + egs/commonvoice/ASR/local/compile_lg.py | 1 + .../compute_fbank_commonvoice_dev_test.py | 4 +- .../ASR/local/preprocess_commonvoice.py | 10 +- egs/commonvoice/ASR/prepare.sh | 64 +- .../README.md | 9 + .../beam_search.py | 1 + .../commonvoice_fr.py | 422 ++++++ .../decode.py | 810 ++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export-for-ncnn-zh.py | 1 + .../export-for-ncnn.py | 1 + .../export-onnx.py | 1 + .../export.py | 1 + .../finetune.py | 1342 +++++++++++++++++ .../generate_model_from_checkpoint.py | 281 ++++ .../jit_pretrained.py | 1 + .../jit_trace_export.py | 1 + .../jit_trace_pretrained.py | 1 + .../joiner.py | 1 + .../model.py | 1 + .../onnx_check.py | 1 + .../onnx_model_wrapper.py | 1 + .../onnx_pretrained.py | 1 + .../optim.py | 1 + .../pretrained.py | 1 + .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming-ncnn-decode.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 612 ++++++++ .../test_model.py | 150 ++ .../train.py | 1256 +++++++++++++++ .../train2.py | 1257 +++++++++++++++ .../zipformer.py | 1 + .../zipformer2.py | 1 + icefall/shared/convert-k2-to-openfst.py | 103 +- icefall/shared/ngram_entropy_pruning.py | 631 +------- icefall/shared/parse_options.sh | 98 +- 42 files changed, 6260 insertions(+), 840 deletions(-) create mode 120000 egs/commonvoice/ASR/local/compile_hlg.py create mode 120000 egs/commonvoice/ASR/local/compile_lg.py create mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py create mode 
120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py mode change 100755 => 120000 icefall/shared/convert-k2-to-openfst.py mode change 100755 => 120000 icefall/shared/ngram_entropy_pruning.py mode change 100755 => 120000 icefall/shared/parse_options.sh diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md index 751625371..2c158d91d 100644 --- a/egs/commonvoice/ASR/RESULTS.md +++ b/egs/commonvoice/ASR/RESULTS.md @@ -57,3 +57,28 @@ Pretrained model is available at The tensorboard log for training is available at + + +### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming) + +#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +See #1018 for more details. + +Number of model parameters: 70369391, i.e., 70.37 M + +The best WER for Common Voice French 12.0 (cv-corpus-12.0-2022-12-07/fr) is below: + +Results are: + +| decoding method | Test | +|----------------------|-------| +| greedy search | 9.95 | +| modified beam search | 9.57 | +| fast beam search | 9.67 | + +Note: This best result is trained on the full librispeech and gigaspeech, and then fine-tuned on the full commonvoice. 
+ +Detailed experimental results and Pretrained model are available at + + diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/commonvoice/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/commonvoice/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py index c8f9b6ccb..a0b4d224c 100755 --- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py @@ -56,8 +56,8 @@ def get_args(): def compute_fbank_commonvoice_dev_test(language: str): src_dir = Path(f"data/{language}/manifests") output_dir = Path(f"data/{language}/fbank") - num_workers = 42 - batch_duration = 600 + num_workers = 16 + batch_duration = 200 subsets = ("dev", "test") diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index e60459765..5f6aa3ec0 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -43,9 +43,13 @@ def get_args(): return parser.parse_args() -def normalize_text(utt: str) -> str: +def normalize_text(utt: str, language: str) -> str: utt = re.sub(r"[{0}]+".format("-"), " ", utt) - return re.sub(r"[^a-zA-Z\s']", "", utt).upper() + utt = re.sub("’", "'", utt) + if language == "en": + return re.sub(r"[^a-zA-Z\s]", "", utt).upper() + if language == "fr": + return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper() def preprocess_commonvoice( @@ -94,7 +98,7 @@ def preprocess_commonvoice( for sup in m["supervisions"]: text = str(sup.text) orig_text = text - sup.text = normalize_text(sup.text) + sup.text = normalize_text(sup.text, language) text = str(sup.text) if len(orig_text) != len(text): logging.info( diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh index 3946908c6..edac0e8e6 100755 --- a/egs/commonvoice/ASR/prepare.sh +++ b/egs/commonvoice/ASR/prepare.sh @@ -36,8 +36,8 @@ num_splits=1000 # - speech dl_dir=$PWD/download -release=cv-corpus-13.0-2023-03-09 -lang=en +release=cv-corpus-12.0-2022-12-07 +lang=fr . shared/parse_options.sh || exit 1 @@ -146,7 +146,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then ./local/compute_fbank_commonvoice_splits.py \ --num-workers $nj \ - --batch-duration 600 \ + --batch-duration 200 \ --start 0 \ --num-splits $num_splits \ --language $lang @@ -189,7 +189,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then sed -i 's/\t/ /g' $lang_dir/transcript_words.txt sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt fi - + if [ ! -f $lang_dir/words.txt ]; then cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ | sort -u | sed '/^$/d' > $lang_dir/words.txt @@ -216,14 +216,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then }' > $lang_dir/words || exit 1; mv $lang_dir/words $lang_dir/words.txt fi - + if [ ! 
-f $lang_dir/bpe.model ]; then ./local/train_bpe_model.py \ --lang-dir $lang_dir \ --vocab-size $vocab_size \ --transcript $lang_dir/transcript_words.txt fi - + if [ ! -f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang_bpe.py --lang-dir $lang_dir @@ -250,3 +250,55 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then fi done fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Prepare G" + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + mkdir -p $lang_dir/lm + #3-gram used in building HLG, 4-gram used for LM rescoring + for ngram in 3 4; do + if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_words.txt \ + -lm $lang_dir/lm/${ngram}gram.arpa + fi + + if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt + fi + done + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Compile HLG" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Compile LG" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md new file mode 100644 index 000000000..991875aaa --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md @@ -0,0 +1,9 @@ +This recipe implements Streaming Zipformer-Transducer model. + +See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials. + +[./emformer.py](./emformer.py) and [./train.py](./train.py) +are basically the same as +[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). 
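For reference, the per-language text normalization introduced above in local/preprocess_commonvoice.py can be exercised on its own. The sketch below mirrors those rules (hyphens collapsed to spaces, the typographic apostrophe unified, then everything outside the per-language character set dropped before upper-casing); the sample sentence, the assert, and the ValueError guard are illustrative additions and are not part of the recipe.

import re

def normalize_text(utt: str, language: str) -> str:
    # Mirrors the rules added in local/preprocess_commonvoice.py:
    # collapse hyphens to spaces, unify the typographic apostrophe,
    # then keep only the characters allowed for the given language
    # and upper-case the result.
    utt = re.sub(r"[-]+", " ", utt)
    utt = re.sub("’", "'", utt)
    if language == "en":
        return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
    if language == "fr":
        return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
    # Guard added only for this sketch; the recipe handles "en" and "fr".
    raise ValueError(f"Unsupported language: {language}")

# Illustrative check (not part of the recipe): punctuation is dropped,
# while the apostrophe, accented letters, and spaces are preserved.
assert normalize_text("BONJOUR, ÇA VA ?", "fr") == "BONJOUR ÇA VA "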
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..d7349b0a3 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py new file mode 100644 index 000000000..cafa4111d --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py @@ -0,0 +1,422 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class CommonVoiceAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--language", + type=str, + default="fr", + help="""Language of Common Voice""", + ) + group.add_argument( + "--cv-manifest-dir", + type=Path, + default=Path("data/fr/fbank"), + help="Path to directory with CommonVoice train/dev/test cuts.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz" + ) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..30f7c1e77 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,810 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + 
fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from commonvoice_fr import CommonVoiceAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += 30 + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, 30), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += 
f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + # errs_info = ( + # params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + # ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + commonvoice = CommonVoiceAsrDataModule(args) + + test_cuts = commonvoice.test_cuts() + + test_dl = commonvoice.test_dataloaders(test_cuts) + + test_sets = "test-cv" + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_sets, + results_dict=results_dict, + ) + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..cb673b3eb --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py new file mode 120000 index 000000000..72e43c297 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 120000 index 000000000..3b36924ef --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 120000 index 000000000..57a0cd0a0 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 120000 index 000000000..2acafdc61 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git 
a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py new file mode 100755 index 000000000..3a10c5d81 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -0,0 +1,1342 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from commonvoice_fr import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--do-finetune", type=str2bool, default=False) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. 
It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="""Embedding dimension in the 2 blocks of zipformer encoder + layers, comma separated + """, + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers,\ + comma separated; not the same as embedding dimension. + """, + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="""Unmasked dimensions in the encoders, relates to augmentation + during training. Must be <= each of encoder_dims. Empirically, less + than 256 seems to make performance worse. + """, + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="""Path to the BPE model. + This should be the bpe model of the original model + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.005, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. 
+ For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. 
+ "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
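# Illustration (not part of this patch): the lines below follow the standard
# torch.cuda.amp pattern (scale -> backward -> step -> update). A minimal
# standalone version of one such step on a toy model; with no GPU present the
# scaler and autocast are simply disabled, so this also runs on CPU.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_fp16 = device.type == "cuda"

toy_model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)

with torch.cuda.amp.autocast(enabled=use_fp16):
    loss = nn.functional.mse_loss(toy_model(x), y)

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales grads; skips the step if they are inf/nan
scaler.update()                # adjust the scale for the next iteration
optimizer.zero_grad()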
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have + # different behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + else: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + train_cuts = commonvoice.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = commonvoice.dev_cuts() + valid_dl = commonvoice.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments( + parser + ) # you may replace this with your own dataset + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py new file mode 100755 index 000000000..3fd14aa47 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`. + +(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`. + +(3) use the original model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path +from typing import Dict, List + +import sentencepiece as spm +import torch +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model." + "If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + print("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + print(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + print(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = ( + params.exp_dir + / f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt" + ) + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = ( + params.exp_dir + / f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt" + ) + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py new file mode 120000 index 000000000..5d9c6ba00 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 120000 index 000000000..457131699 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 120000 index 000000000..2b8fa3cbb --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..e17d4f734 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 120000 index 000000000..28bf7bb82 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 120000 index 000000000..c8548d459 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 120000 index 000000000..ae4d9bb04 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 120000 index 000000000..9510b8fde --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 120000 index 000000000..92c3904af --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..2adf271c1 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..dbe65d0a7 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from commonvoice_fr import CommonVoiceAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
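The comment above pins the minimum chunk length at 23 frames: after the convolutional embedding, (23 - 7) // 2 = 8 frames remain, one full window of the encoder's largest downsampling factor. A minimal, self-contained sketch (illustrative only, not part of the patch) of the padding that the next few lines apply to a too-short final chunk:

    import math

    import torch
    import torch.nn.functional as F

    LOG_EPS = math.log(1e-10)   # same padding value as in this script
    tail_length = 23            # (23 - 7) // 2 == 8, the encoder's max downsampling factor

    feats = torch.zeros(2, 20, 80)            # (N, T, C) final chunk shorter than tail_length
    pad_length = tail_length - feats.size(1)
    # (0, 0, 0, pad_length) pads the last dim by (0, 0) and the time dim by (0, pad_length).
    padded = F.pad(feats, (0, 0, 0, pad_length), mode="constant", value=LOG_EPS)
    print(padded.shape)                       # torch.Size([2, 23, 80])
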
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + idx = 0 + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
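decode_dataset() manages a pool of such per-utterance DecodeStream objects: it keeps adding streams until --num-decode-streams are in flight, advances them one chunk at a time with decode_one_chunk(), and pops finished ones in reverse index order so the remaining indexes stay valid. A stripped-down mock of that pattern (illustrative only; fake_step is a hypothetical stand-in for decode_one_chunk and returns the indexes of finished streams):

    def fake_step(pool):
        # Pretend the oldest stream in the pool finishes on every call.
        return [0] if pool else []

    num_decode_streams = 3
    pool, results = [], []
    for utt_id in range(10):                      # one entry per utterance/cut
        pool.append(utt_id)
        while len(pool) >= num_decode_streams:    # keep at most N streams in flight
            for i in sorted(fake_step(pool), reverse=True):
                results.append(pool[i])
                del pool[i]                       # deleting from the back keeps indexes valid
    while pool:                                   # drain the final, not-yet-full pool
        for i in sorted(fake_step(pool), reverse=True):
            results.append(pool[i])
            del pool[i]
    print(len(results))                           # 10: every utterance decoded exactly once
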
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + audio: np.ndarray = cut.load_audio() + if audio.max() > 1 or audio.min() < -1: + audio = audio / max(abs(audio.max()), abs(audio.min())) + print(audio) + print(audio.max()) + print(audio.min()) + print(cut) + idx += 1 + print(idx) + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
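One detail from the per-cut loop above: the trained model consumes waveforms normalised to [-1, 1], so audio that arrives at int16 scale is rescaled by its peak absolute amplitude before feature extraction. A small sketch with made-up sample values:

    import numpy as np

    audio = np.array([[12000.0, -32768.0, 500.0]], dtype=np.float32)  # int16-scale samples
    if audio.max() > 1 or audio.min() < -1:
        audio = audio / max(abs(audio.max()), abs(audio.min()))
    assert audio.max() <= 1 and audio.min() >= -1
    print(audio.min(), audio.max())   # -1.0 0.36621094
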
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not 
enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + commonvoice = CommonVoiceAsrDataModule(args) + test_cuts = commonvoice.test_cuts() + test_sets = "test-cv" + + results_dict = decode_dataset( + cuts=test_cuts, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_sets, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..5400df804 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
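For reference, the checkpoint arithmetic that main() above applies when --use-averaged-model is set with --epoch and --avg (the values below are just the script defaults used as an example, and the exp dir is hypothetical). Only the two endpoint checkpoints are loaded; the averaged parameters stored inside them are combined over the range from the excluded start epoch to the included end epoch:

    epoch, avg = 28, 15
    start = epoch - avg
    assert avg > 0 and start >= 1

    exp_dir = "pruned_transducer_stateless7_streaming/exp"     # hypothetical --exp-dir
    filename_start = f"{exp_dir}/epoch-{start}.pt"             # excluded endpoint of the range
    filename_end = f"{exp_dir}/epoch-{epoch}.pt"               # included endpoint
    print(filename_start, filename_end)                        # .../epoch-13.pt .../epoch-28.pt
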
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert 
torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..a9bc9c2a2 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1256 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
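test_model.py above applies the same trace-and-compare pattern to the encoder, decoder and joiner: trace the module on example inputs, then verify the traced module matches eager execution on fresh inputs. A self-contained toy version (an nn.Linear stands in for the real modules):

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 3).eval()
    traced = torch.jit.trace(layer, torch.randn(2, 4))   # trace on example inputs

    x = torch.randn(5, 4)                                # fresh inputs, different batch size
    assert torch.allclose(layer(x), traced(x), atol=1e-6)
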
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from commonvoice_fr import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/fr/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
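The comment above describes a manual growth rule layered on top of GradScaler: once every 100 batches the scale is doubled if it has fallen below 1.0, and once every 400 batches if it is still below 8.0. A pure-Python sketch of just that rule (illustrative; the real code reads the scale from scaler._scale and writes it back with scaler.update()):

    def maybe_grow(cur_grad_scale: float, batch_idx: int) -> float:
        if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
            return cur_grad_scale * 2.0
        return cur_grad_scale

    scale = 0.25                                # e.g. after a few overflow-triggered halvings
    for batch_idx in range(100, 1300, 100):     # the check itself only runs every 100 batches
        scale = maybe_grow(scale, batch_idx)
    print(scale)                                # 8.0
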
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + train_cuts = commonvoice.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = commonvoice.dev_cuts() + valid_dl = commonvoice.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..c09c9537c --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1257 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
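train.py above filters training cuts both by duration (1 s to 20 s) and by the pruned RNN-T requirement that the number of subsampled frames T be at least the number of BPE tokens S. A quick check of that condition, using made-up frame and token counts:

    def subsampled_frames(num_frames: int) -> int:
        # Same expression as in remove_short_and_long_utt() above.
        return ((num_frames - 7) // 2 + 1) // 2

    num_frames = 120          # ~1.2 s of audio at a 10 ms frame shift
    num_tokens = 25           # hypothetical BPE token count for the transcript
    T = subsampled_frames(num_frames)
    print(T, num_tokens, T >= num_tokens)   # 28 25 True -> the cut would be kept
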
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from commonvoice_fr import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
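+            # Concretely, the checks below run every 100 batches when fp16 is
+            # enabled: a grad scale below 1.0 is doubled on every such check,
+            # and a scale below 8.0 is doubled on every 4th check (i.e. every
+            # 400 batches).  A scale below 0.01 only triggers a warning, while
+            # a scale below 1.0e-05 aborts training, since a collapsing scale
+            # usually indicates repeated inf/nan gradients.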
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + train_cuts = commonvoice.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = commonvoice.dev_cuts() + valid_dl = commonvoice.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 120000 index 000000000..12dbda888 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py deleted file mode 100755 index 29a2cd7f7..000000000 --- a/icefall/shared/convert-k2-to-openfst.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script takes as input an FST in k2 format and convert it -to an FST in OpenFST format. - -The generated FST is saved into a binary file and its type is -StdVectorFst. 
- -Usage examples: -(1) Convert an acceptor - - ./convert-k2-to-openfst.py in.pt binary.fst - -(2) Convert a transducer - - ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import kaldifst.utils -import torch - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--olabels", - type=str, - default=None, - help="""If not empty, the input FST is assumed to be a transducer - and we use its attribute specified by "olabels" as the output labels. - """, - ) - parser.add_argument( - "input_filename", - type=str, - help="Path to the input FST in k2 format", - ) - - parser.add_argument( - "output_filename", - type=str, - help="Path to the output FST in OpenFst format", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - logging.info(f"{vars(args)}") - - input_filename = args.input_filename - output_filename = args.output_filename - olabels = args.olabels - - if Path(output_filename).is_file(): - logging.info(f"{output_filename} already exists - skipping") - return - - assert Path(input_filename).is_file(), f"{input_filename} does not exist" - logging.info(f"Loading {input_filename}") - k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) - if olabels: - assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" - - p = Path(output_filename).parent - if not p.is_dir(): - logging.info(f"Creating {p}") - p.mkdir(parents=True) - - logging.info("Converting (May take some time if the input FST is large)") - fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels) - logging.info(f"Saving to {output_filename}") - fst.write(output_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py new file mode 120000 index 000000000..24efe5eae --- /dev/null +++ b/icefall/shared/convert-k2-to-openfst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/convert-k2-to-openfst.py \ No newline at end of file diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py deleted file mode 100755 index b1ebee9ea..000000000 --- a/icefall/shared/ngram_entropy_pruning.py +++ /dev/null @@ -1,630 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# -# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -./ngram_entropy_pruning.py \ - -threshold 1e-8 \ - -lm download/lm/4gram.arpa \ - -write-lm download/lm/4gram_pruned_1e8.arpa - -This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`. -This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' -in the same way as SRILM. 
-""" - - -import argparse -import gzip -import logging -import math -import re -from collections import OrderedDict, defaultdict -from enum import Enum, unique -from io import StringIO - -parser = argparse.ArgumentParser( - description=""" - Prune an n-gram language model based on the relative entropy - between the original and the pruned model, based on Andreas Stolcke's paper. - An n-gram entry is removed, if the removal causes (training set) perplexity - of the model to increase by less than threshold relative. - - The command takes an arpa file and a pruning threshold as input, - and outputs a pruned arpa file. - """ -) -parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram") -parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") -parser.add_argument( - "-write-lm", type=str, default=None, help="Path to output arpa file after pruning" -) -parser.add_argument( - "-minorder", - type=int, - default=1, - help="The minorder parameter limits pruning to ngrams of that length and above.", -) -parser.add_argument( - "-encoding", type=str, default="utf-8", help="Encoding of the arpa file" -) -parser.add_argument( - "-verbose", - type=int, - default=2, - choices=[0, 1, 2, 3, 4, 5], - help="Verbose level, where 0 is most noisy; 5 is most silent", -) -args = parser.parse_args() - -default_encoding = args.encoding -logging.basicConfig( - format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", - level=args.verbose * 10, -) - - -class Context(dict): - """ - This class stores data for a context h. - It behaves like a python dict object, except that it has several - additional attributes. - """ - - def __init__(self): - super().__init__() - self.log_bo = None - - -class Arpa: - """ - This is a class that implement the data structure of an APRA LM. 
- It (as well as some other classes) is modified based on the library - by Stefan Fischer: - https://github.com/sfischer13/python-arpa - """ - - UNK = "" - SOS = "" - EOS = "" - FLOAT_NDIGITS = 7 - base = 10 - - @staticmethod - def _check_input(my_input): - if not my_input: - raise ValueError - elif isinstance(my_input, tuple): - return my_input - elif isinstance(my_input, list): - return tuple(my_input) - elif isinstance(my_input, str): - return tuple(my_input.strip().split(" ")) - else: - raise ValueError - - @staticmethod - def _check_word(input_word): - if not isinstance(input_word, str): - raise ValueError - if " " in input_word: - raise ValueError - - def _replace_unks(self, words): - return tuple((w if w in self else self._unk) for w in words) - - def __init__(self, path=None, encoding=None, unk=None): - self._counts = OrderedDict() - self._ngrams = ( - OrderedDict() - ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) - self._vocabulary = set() - if unk is None: - self._unk = self.UNK - - if path is not None: - self.loadf(path, encoding) - - def __contains__(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] - - def contains_word(self, word): - self._check_word(word) - return word in self._vocabulary - - def add_count(self, order, count): - self._counts[order] = count - self._ngrams[order - 1] = defaultdict(Context) - - def update_counts(self): - for order in range(1, self.order() + 1): - count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) - if count > 0: - self._counts[order] = count - - def add_entry(self, ngram, p, bo=None, order=None): - # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - - # Note that p and bo here are in fact in the log domain (self.base = 10) - h_context = self._ngrams[len(h)][h] - h_context[w] = p - if bo is not None: - self._ngrams[len(ngram)][ngram].log_bo = bo - - for word in ngram: - self._vocabulary.add(word) - - def counts(self): - return sorted(self._counts.items()) - - def order(self): - return max(self._counts.keys(), default=None) - - def vocabulary(self, sort=True): - if sort: - return sorted(self._vocabulary) - else: - return self._vocabulary - - def _entries(self, order): - return ( - self._entry(h, w) - for h, wlist in self._ngrams[order - 1].items() - for w in wlist - ) - - def _entry(self, h, w): - # return the entry for the ngram (h, w) - ngram = h + (w,) - log_p = self._ngrams[len(h)][h][w] - log_bo = self._log_bo(ngram) - if log_bo is not None: - return ( - round(log_p, self.FLOAT_NDIGITS), - ngram, - round(log_bo, self.FLOAT_NDIGITS), - ) - else: - return round(log_p, self.FLOAT_NDIGITS), ngram - - def _log_bo(self, ngram): - if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: - return self._ngrams[len(ngram)][ngram].log_bo - else: - return None - - def _log_p(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: - return self._ngrams[len(h)][h][w] - else: - return None - - def log_p_raw(self, ngram): - log_p = self._log_p(ngram) - if log_p is not None: - return log_p - else: - if len(ngram) == 1: - raise KeyError - else: - log_bo = self._log_bo(ngram[:-1]) - if log_bo is None: - log_bo = 0 - return log_bo + self.log_p_raw(ngram[1:]) - - def log_joint_prob(self, sequence): - # Compute the joint prob of the sequence 
based on the chain rule - # Note that sequence should be a tuple of strings - # - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 - - log_joint_p = 0 - seq = sequence - while len(seq) > 0: - log_joint_p += self.log_p_raw(seq) - seq = seq[:-1] - - # If we're computing the marginal probability of the unigram - # context we have to look up instead since the former - # has prob = 0. - if len(seq) == 1 and seq[0] == self.SOS: - seq = (self.EOS,) - - return log_joint_p - - def set_new_context(self, h): - old_context = self._ngrams[len(h)][h] - self._ngrams[len(h)][h] = Context() - return old_context - - def log_p(self, ngram): - words = self._check_input(ngram) - if self._unk: - words = self._replace_unks(words) - return self.log_p_raw(words) - - def log_s(self, sentence, sos=SOS, eos=EOS): - words = self._check_input(sentence) - if self._unk: - words = self._replace_unks(words) - if sos: - words = (sos,) + words - if eos: - words = words + (eos,) - result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) - if sos: - result = result - self.log_p_raw(words[:1]) - return result - - def p(self, ngram): - return self.base ** self.log_p(ngram) - - def s(self, sentence): - return self.base ** self.log_s(sentence) - - def write(self, fp): - fp.write("\n\\data\\\n") - for order, count in self.counts(): - fp.write("ngram {}={}\n".format(order, count)) - fp.write("\n") - for order, _ in self.counts(): - fp.write("\\{}-grams:\n".format(order)) - for e in self._entries(order): - prob = e[0] - ngram = " ".join(e[1]) - if len(e) == 2: - fp.write("{}\t{}\n".format(prob, ngram)) - elif len(e) == 3: - backoff = e[2] - fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) - else: - raise ValueError - fp.write("\n") - fp.write("\\end\\\n") - - -class ArpaParser: - """ - This is a class that implement a parser of an arpa file - """ - - @unique - class State(Enum): - DATA = 1 - COUNT = 2 - HEADER = 3 - ENTRY = 4 - - re_count = re.compile(r"^ngram (\d+)=(\d+)$") - re_header = re.compile(r"^\\(\d+)-grams:$") - re_entry = re.compile( - "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" - "\t" - "(\\S+( \\S+)*)" - "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" - ) - - def _parse(self, fp): - self._result = [] - self._state = self.State.DATA - self._tmp_model = None - self._tmp_order = None - for line in fp: - line = line.strip() - if self._state == self.State.DATA: - self._data(line) - elif self._state == self.State.COUNT: - self._count(line) - elif self._state == self.State.HEADER: - self._header(line) - elif self._state == self.State.ENTRY: - self._entry(line) - if self._state != self.State.DATA: - raise Exception(line) - return self._result - - def _data(self, line): - if line == "\\data\\": - self._state = self.State.COUNT - self._tmp_model = Arpa() - else: - pass # skip comment line - - def _count(self, line): - match = self.re_count.match(line) - if match: - order = match.group(1) - count = match.group(2) - self._tmp_model.add_count(int(order), int(count)) - elif not line: - self._state = self.State.HEADER # there are no counts - else: - raise Exception(line) - - def _header(self, line): - match = self.re_header.match(line) - if match: - self._state = self.State.ENTRY - self._tmp_order = int(match.group(1)) - elif line == "\\end\\": - self._result.append(self._tmp_model) - self._state = self.State.DATA - self._tmp_model = None - self._tmp_order = None - elif not line: - pass # skip empty line - else: - raise Exception(line) - - def _entry(self, line): 
- match = self.re_entry.match(line) - if match: - p = self._float_or_int(match.group(1)) - ngram = tuple(match.group(4).split(" ")) - bo_match = match.group(7) - bo = self._float_or_int(bo_match) if bo_match else None - self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) - elif not line: - self._state = self.State.HEADER # last entry - else: - raise Exception(line) - - @staticmethod - def _float_or_int(s): - f = float(s) - i = int(f) - if str(i) == s: # don't drop trailing ".0" - return i - else: - return f - - def load(self, fp): - """Deserialize fp (a file-like object) to a Python object.""" - return self._parse(fp) - - def loadf(self, path, encoding=None): - """Deserialize path (.arpa, .gz) to a Python object.""" - path = str(path) - if path.endswith(".gz"): - with gzip.open(path, mode="rt", encoding=encoding) as f: - return self.load(f) - else: - with open(path, mode="rt", encoding=encoding) as f: - return self.load(f) - - def loads(self, s): - """Deserialize s (a str) to a Python object.""" - with StringIO(s) as f: - return self.load(f) - - def dump(self, obj, fp): - """Serialize obj to fp (a file-like object) in ARPA format.""" - obj.write(fp) - - def dumpf(self, obj, path, encoding=None): - """Serialize obj to path in ARPA format (.arpa, .gz).""" - path = str(path) - if path.endswith(".gz"): - with gzip.open(path, mode="wt", encoding=encoding) as f: - return self.dump(obj, f) - else: - with open(path, mode="wt", encoding=encoding) as f: - self.dump(obj, f) - - def dumps(self, obj): - """Serialize obj to an ARPA formatted str.""" - with StringIO() as f: - self.dump(obj, f) - return f.getvalue() - - -def add_log_p(prev_log_sum, log_p, base): - return math.log(base**log_p + base**prev_log_sum, base) - - -def compute_numerator_denominator(lm, h): - log_sum_seen_h = -math.inf - log_sum_seen_h_lower = -math.inf - base = lm.base - for w, log_p in lm._ngrams[len(h)][h].items(): - log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) - - ngram = h + (w,) - log_p_lower = lm.log_p_raw(ngram[1:]) - log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) - - numerator = 1.0 - base**log_sum_seen_h - denominator = 1.0 - base**log_sum_seen_h_lower - return numerator, denominator - - -def prune(lm, threshold, minorder): - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 - - for i in range( - lm.order(), max(minorder - 1, 1), -1 - ): # i is the order of the ngram (h, w) - logging.info("processing %d-grams ..." % i) - count_pruned_ngrams = 0 - - h_dict = lm._ngrams[i - 1] - for h in list(h_dict.keys()): - # old backoff weight, BOW(h) - log_bow = lm._log_bo(h) - if log_bow is None: - log_bow = 0 - - # Compute numerator and denominator of the backoff weight, - # so that we can quickly compute the BOW adjustment due to - # leaving out one prob. 
- numerator, denominator = compute_numerator_denominator(lm, h) - - # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 - - # Compute the marginal probability of the context, P(h) - h_log_p = lm.log_joint_prob(h) - - all_pruned = True - pruned_w_set = set() - - for w, log_p in h_dict[h].items(): - ngram = h + (w,) - - # lower-order estimate for ngramProb, P(w|h') - backoff_prob = lm.log_p_raw(ngram[1:]) - - # Compute BOW after removing ngram, BOW'(h) - new_log_bow = math.log( - numerator + lm.base**log_p, lm.base - ) - math.log(denominator + lm.base**backoff_prob, lm.base) - - # Compute change in entropy due to removal of ngram - delta_prob = backoff_prob + new_log_bow - log_p - delta_entropy = -(lm.base**h_log_p) * ( - (lm.base**log_p) * delta_prob - + numerator * (new_log_bow - log_bow) - ) - - # compute relative change in model (training set) perplexity - perp_change = lm.base**delta_entropy - 1.0 - - pruned = threshold > 0 and perp_change < threshold - - # Make sure we don't prune ngrams whose backoff nodes are needed - if ( - pruned - and len(ngram) in lm._ngrams - and len(lm._ngrams[len(ngram)][ngram]) > 0 - ): - pruned = False - - logging.debug( - "CONTEXT " - + str(h) - + " WORD " - + w - + " CONTEXTPROB %f " % h_log_p - + " OLDPROB %f " % log_p - + " NEWPROB %f " % (backoff_prob + new_log_bow) - + " DELTA-H %f " % delta_entropy - + " DELTA-LOGP %f " % delta_prob - + " PPL-CHANGE %f " % perp_change - + " PRUNED " - + str(pruned) - ) - - if pruned: - pruned_w_set.add(w) - count_pruned_ngrams += 1 - else: - all_pruned = False - - # If we removed all ngrams for this context we can - # remove the context itself, but only if the present - # context is not a prefix to a longer one. - if all_pruned and len(pruned_w_set) == len(h_dict[h]): - del h_dict[ - h - ] # this context h is no longer needed, as its ngram prob is stored at its own context h' - elif len(pruned_w_set) > 0: - # The pruning for this context h is actually done here - old_context = lm.set_new_context(h) - - for w, p_w in old_context.items(): - if w not in pruned_w_set: - lm.add_entry( - h + (w,), p_w - ) # the entry hw is stored at the context h - - # We need to recompute the back-off weight, but - # this can only be done after completing the pruning - # of the lower-order ngrams. - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 - - logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) - - # recompute backoff weights - for i in range( - max(minorder - 1, 1) + 1, lm.order() + 1 - ): # be careful of this order: from low- to high-order - for h in lm._ngrams[i - 1]: - numerator, denominator = compute_numerator_denominator(lm, h) - new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) - lm._ngrams[len(h)][h].log_bo = new_log_bow - - # update counts - lm.update_counts() - - return - - -def check_h_is_valid(lm, h): - sum_under_h = sum( - [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] - ) - if abs(sum_under_h - 1.0) > 1e-6: - logging.info("warning: %s %f" % (str(h), sum_under_h)) - return False - else: - return True - - -def validate_lm(lm): - # sanity check if the conditional probability sums to one under each context h - for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) - logging.info("validating %d-grams ..." 
% i) - h_dict = lm._ngrams[i - 1] - for h in h_dict.keys(): - check_h_is_valid(lm, h) - - -def compare_two_apras(path1, path2): - pass - - -if __name__ == "__main__": - # load an arpa file - logging.info("Loading the arpa file from %s" % args.lm) - parser = ArpaParser() - models = parser.loadf(args.lm, encoding=default_encoding) - lm = models[0] # ARPA files may contain several models. - logging.info("Stats before pruning:") - for i, cnt in lm.counts(): - logging.info("ngram %d=%d" % (i, cnt)) - - # prune it, the language model will be modified in-place - logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) - prune(lm, args.threshold, args.minorder) - - # validate_lm(lm) - - # write the arpa language model to a file - logging.info("Stats after pruning:") - for i, cnt in lm.counts(): - logging.info("ngram %d=%d" % (i, cnt)) - logging.info("Saving the pruned arpa file to %s" % args.write_lm) - parser.dumpf(lm, args.write_lm, encoding=default_encoding) - logging.info("Done.") diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py new file mode 120000 index 000000000..0e14ac415 --- /dev/null +++ b/icefall/shared/ngram_entropy_pruning.py @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/ngram_entropy_pruning.py \ No newline at end of file diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh deleted file mode 100755 index 71fb9e5ea..000000000 --- a/icefall/shared/parse_options.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash - -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); -# Arnab Ghoshal, Karel Vesely - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# Parse command-line options. -# To be sourced by another script (as in ". parse_options.sh"). -# Option format is: --option-name arg -# and shell variable "option_name" gets set to value "arg." -# The exception is --help, which takes no arguments, but prints the -# $help_message variable (if defined). - - -### -### The --config file options have lower priority to command line -### options, so we need to import them first... -### - -# Now import all the configs specified by command-line, in left-to-right order -for ((argpos=1; argpos<$#; argpos++)); do - if [ "${!argpos}" == "--config" ]; then - argpos_plus1=$((argpos+1)) - config=${!argpos_plus1} - [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 - . $config # source the config file. - fi -done - - -### -### Now we process the command line options -### -while true; do - [ -z "${1:-}" ] && break; # break if there are no arguments - case "$1" in - # If the enclosing script is called with --help option, print the help - # message and exit. Scripts should put help messages in $help_message - --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; - else printf "$help_message\n" 1>&2 ; fi; - exit 0 ;; - --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" - exit 1 ;; - # If the first command-line argument begins with "--" (e.g. --foo-bar), - # then work out the variable name as $name, which will equal "foo_bar". - --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; - # Next we test whether the variable in question is undefned-- if so it's - # an invalid option and we die. Note: $0 evaluates to the name of the - # enclosing script. - # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar - # is undefined. We then have to wrap this test inside "eval" because - # foo_bar is itself inside a variable ($name). - eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; - - oldval="`eval echo \\$$name`"; - # Work out whether we seem to be expecting a Boolean argument. - if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then - was_bool=true; - else - was_bool=false; - fi - - # Set the variable to the right value-- the escaped quotes make it work if - # the option had spaces, like --cmd "queue.pl -sync y" - eval $name=\"$2\"; - - # Check that Boolean-valued arguments are really Boolean. - if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then - echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 - exit 1; - fi - shift 2; - ;; - *) break; - esac -done - - -# Check for an empty argument to the --cmd option, which can easily occur as a -# result of scripting errors. -[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; - - -true; # so this script returns exit code 0. diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh new file mode 120000 index 000000000..e4665e7de --- /dev/null +++ b/icefall/shared/parse_options.sh @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file From 6d275ddf9fdd67a32b79b93d70fedffe4b156d5c Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 10 Nov 2023 14:45:16 +0800 Subject: [PATCH 046/216] fixed broken softlinks (#1381) * removed broken softlinks * fixed dependencies * fixed file permission --- icefall/shared/convert-k2-to-openfst.py | 103 +++- icefall/shared/ngram_entropy_pruning.py | 631 +++++++++++++++++++++++- icefall/shared/parse_options.sh | 98 +++- 3 files changed, 829 insertions(+), 3 deletions(-) mode change 120000 => 100755 icefall/shared/convert-k2-to-openfst.py mode change 120000 => 100755 icefall/shared/ngram_entropy_pruning.py mode change 120000 => 100755 icefall/shared/parse_options.sh diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py deleted file mode 120000 index 24efe5eae..000000000 --- a/icefall/shared/convert-k2-to-openfst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/convert-k2-to-openfst.py \ No newline at end of file diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py new file mode 100755 index 000000000..29a2cd7f7 --- /dev/null +++ b/icefall/shared/convert-k2-to-openfst.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes as input an FST in k2 format and convert it +to an FST in OpenFST format. + +The generated FST is saved into a binary file and its type is +StdVectorFst. + +Usage examples: +(1) Convert an acceptor + + ./convert-k2-to-openfst.py in.pt binary.fst + +(2) Convert a transducer + + ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import kaldifst.utils +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--olabels", + type=str, + default=None, + help="""If not empty, the input FST is assumed to be a transducer + and we use its attribute specified by "olabels" as the output labels. + """, + ) + parser.add_argument( + "input_filename", + type=str, + help="Path to the input FST in k2 format", + ) + + parser.add_argument( + "output_filename", + type=str, + help="Path to the output FST in OpenFst format", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.info(f"{vars(args)}") + + input_filename = args.input_filename + output_filename = args.output_filename + olabels = args.olabels + + if Path(output_filename).is_file(): + logging.info(f"{output_filename} already exists - skipping") + return + + assert Path(input_filename).is_file(), f"{input_filename} does not exist" + logging.info(f"Loading {input_filename}") + k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) + if olabels: + assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" + + p = Path(output_filename).parent + if not p.is_dir(): + logging.info(f"Creating {p}") + p.mkdir(parents=True) + + logging.info("Converting (May take some time if the input FST is large)") + fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels) + logging.info(f"Saving to {output_filename}") + fst.write(output_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py deleted file mode 120000 index 0e14ac415..000000000 --- a/icefall/shared/ngram_entropy_pruning.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/ngram_entropy_pruning.py \ No newline at end of file diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py new file mode 100755 index 000000000..b1ebee9ea --- /dev/null +++ b/icefall/shared/ngram_entropy_pruning.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./ngram_entropy_pruning.py \ + -threshold 1e-8 \ + -lm download/lm/4gram.arpa \ + -write-lm download/lm/4gram_pruned_1e8.arpa + +This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`. +This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' +in the same way as SRILM. +""" + + +import argparse +import gzip +import logging +import math +import re +from collections import OrderedDict, defaultdict +from enum import Enum, unique +from io import StringIO + +parser = argparse.ArgumentParser( + description=""" + Prune an n-gram language model based on the relative entropy + between the original and the pruned model, based on Andreas Stolcke's paper. + An n-gram entry is removed, if the removal causes (training set) perplexity + of the model to increase by less than threshold relative. + + The command takes an arpa file and a pruning threshold as input, + and outputs a pruned arpa file. + """ +) +parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram") +parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") +parser.add_argument( + "-write-lm", type=str, default=None, help="Path to output arpa file after pruning" +) +parser.add_argument( + "-minorder", + type=int, + default=1, + help="The minorder parameter limits pruning to ngrams of that length and above.", +) +parser.add_argument( + "-encoding", type=str, default="utf-8", help="Encoding of the arpa file" +) +parser.add_argument( + "-verbose", + type=int, + default=2, + choices=[0, 1, 2, 3, 4, 5], + help="Verbose level, where 0 is most noisy; 5 is most silent", +) +args = parser.parse_args() + +default_encoding = args.encoding +logging.basicConfig( + format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", + level=args.verbose * 10, +) + + +class Context(dict): + """ + This class stores data for a context h. + It behaves like a python dict object, except that it has several + additional attributes. + """ + + def __init__(self): + super().__init__() + self.log_bo = None + + +class Arpa: + """ + This is a class that implement the data structure of an APRA LM. 
+ It (as well as some other classes) is modified based on the library + by Stefan Fischer: + https://github.com/sfischer13/python-arpa + """ + + UNK = "" + SOS = "" + EOS = "" + FLOAT_NDIGITS = 7 + base = 10 + + @staticmethod + def _check_input(my_input): + if not my_input: + raise ValueError + elif isinstance(my_input, tuple): + return my_input + elif isinstance(my_input, list): + return tuple(my_input) + elif isinstance(my_input, str): + return tuple(my_input.strip().split(" ")) + else: + raise ValueError + + @staticmethod + def _check_word(input_word): + if not isinstance(input_word, str): + raise ValueError + if " " in input_word: + raise ValueError + + def _replace_unks(self, words): + return tuple((w if w in self else self._unk) for w in words) + + def __init__(self, path=None, encoding=None, unk=None): + self._counts = OrderedDict() + self._ngrams = ( + OrderedDict() + ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) + self._vocabulary = set() + if unk is None: + self._unk = self.UNK + + if path is not None: + self.loadf(path, encoding) + + def __contains__(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] + + def contains_word(self, word): + self._check_word(word) + return word in self._vocabulary + + def add_count(self, order, count): + self._counts[order] = count + self._ngrams[order - 1] = defaultdict(Context) + + def update_counts(self): + for order in range(1, self.order() + 1): + count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) + if count > 0: + self._counts[order] = count + + def add_entry(self, ngram, p, bo=None, order=None): + # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + + # Note that p and bo here are in fact in the log domain (self.base = 10) + h_context = self._ngrams[len(h)][h] + h_context[w] = p + if bo is not None: + self._ngrams[len(ngram)][ngram].log_bo = bo + + for word in ngram: + self._vocabulary.add(word) + + def counts(self): + return sorted(self._counts.items()) + + def order(self): + return max(self._counts.keys(), default=None) + + def vocabulary(self, sort=True): + if sort: + return sorted(self._vocabulary) + else: + return self._vocabulary + + def _entries(self, order): + return ( + self._entry(h, w) + for h, wlist in self._ngrams[order - 1].items() + for w in wlist + ) + + def _entry(self, h, w): + # return the entry for the ngram (h, w) + ngram = h + (w,) + log_p = self._ngrams[len(h)][h][w] + log_bo = self._log_bo(ngram) + if log_bo is not None: + return ( + round(log_p, self.FLOAT_NDIGITS), + ngram, + round(log_bo, self.FLOAT_NDIGITS), + ) + else: + return round(log_p, self.FLOAT_NDIGITS), ngram + + def _log_bo(self, ngram): + if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: + return self._ngrams[len(ngram)][ngram].log_bo + else: + return None + + def _log_p(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: + return self._ngrams[len(h)][h][w] + else: + return None + + def log_p_raw(self, ngram): + log_p = self._log_p(ngram) + if log_p is not None: + return log_p + else: + if len(ngram) == 1: + raise KeyError + else: + log_bo = self._log_bo(ngram[:-1]) + if log_bo is None: + log_bo = 0 + return log_bo + self.log_p_raw(ngram[1:]) + + def log_joint_prob(self, sequence): + # Compute the joint prob of the sequence 
based on the chain rule + # Note that sequence should be a tuple of strings + # + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 + + log_joint_p = 0 + seq = sequence + while len(seq) > 0: + log_joint_p += self.log_p_raw(seq) + seq = seq[:-1] + + # If we're computing the marginal probability of the unigram + # context we have to look up instead since the former + # has prob = 0. + if len(seq) == 1 and seq[0] == self.SOS: + seq = (self.EOS,) + + return log_joint_p + + def set_new_context(self, h): + old_context = self._ngrams[len(h)][h] + self._ngrams[len(h)][h] = Context() + return old_context + + def log_p(self, ngram): + words = self._check_input(ngram) + if self._unk: + words = self._replace_unks(words) + return self.log_p_raw(words) + + def log_s(self, sentence, sos=SOS, eos=EOS): + words = self._check_input(sentence) + if self._unk: + words = self._replace_unks(words) + if sos: + words = (sos,) + words + if eos: + words = words + (eos,) + result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) + if sos: + result = result - self.log_p_raw(words[:1]) + return result + + def p(self, ngram): + return self.base ** self.log_p(ngram) + + def s(self, sentence): + return self.base ** self.log_s(sentence) + + def write(self, fp): + fp.write("\n\\data\\\n") + for order, count in self.counts(): + fp.write("ngram {}={}\n".format(order, count)) + fp.write("\n") + for order, _ in self.counts(): + fp.write("\\{}-grams:\n".format(order)) + for e in self._entries(order): + prob = e[0] + ngram = " ".join(e[1]) + if len(e) == 2: + fp.write("{}\t{}\n".format(prob, ngram)) + elif len(e) == 3: + backoff = e[2] + fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) + else: + raise ValueError + fp.write("\n") + fp.write("\\end\\\n") + + +class ArpaParser: + """ + This is a class that implement a parser of an arpa file + """ + + @unique + class State(Enum): + DATA = 1 + COUNT = 2 + HEADER = 3 + ENTRY = 4 + + re_count = re.compile(r"^ngram (\d+)=(\d+)$") + re_header = re.compile(r"^\\(\d+)-grams:$") + re_entry = re.compile( + "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" + "\t" + "(\\S+( \\S+)*)" + "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" + ) + + def _parse(self, fp): + self._result = [] + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + for line in fp: + line = line.strip() + if self._state == self.State.DATA: + self._data(line) + elif self._state == self.State.COUNT: + self._count(line) + elif self._state == self.State.HEADER: + self._header(line) + elif self._state == self.State.ENTRY: + self._entry(line) + if self._state != self.State.DATA: + raise Exception(line) + return self._result + + def _data(self, line): + if line == "\\data\\": + self._state = self.State.COUNT + self._tmp_model = Arpa() + else: + pass # skip comment line + + def _count(self, line): + match = self.re_count.match(line) + if match: + order = match.group(1) + count = match.group(2) + self._tmp_model.add_count(int(order), int(count)) + elif not line: + self._state = self.State.HEADER # there are no counts + else: + raise Exception(line) + + def _header(self, line): + match = self.re_header.match(line) + if match: + self._state = self.State.ENTRY + self._tmp_order = int(match.group(1)) + elif line == "\\end\\": + self._result.append(self._tmp_model) + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + elif not line: + pass # skip empty line + else: + raise Exception(line) + + def _entry(self, line): 
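+        # Parse one entry line of the form
+        # "<log10 prob>\t<word1 ... wordN>[\t<log10 backoff>]"
+        # and add the n-gram to the model currently being built.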
+ match = self.re_entry.match(line) + if match: + p = self._float_or_int(match.group(1)) + ngram = tuple(match.group(4).split(" ")) + bo_match = match.group(7) + bo = self._float_or_int(bo_match) if bo_match else None + self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) + elif not line: + self._state = self.State.HEADER # last entry + else: + raise Exception(line) + + @staticmethod + def _float_or_int(s): + f = float(s) + i = int(f) + if str(i) == s: # don't drop trailing ".0" + return i + else: + return f + + def load(self, fp): + """Deserialize fp (a file-like object) to a Python object.""" + return self._parse(fp) + + def loadf(self, path, encoding=None): + """Deserialize path (.arpa, .gz) to a Python object.""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + else: + with open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + + def loads(self, s): + """Deserialize s (a str) to a Python object.""" + with StringIO(s) as f: + return self.load(f) + + def dump(self, obj, fp): + """Serialize obj to fp (a file-like object) in ARPA format.""" + obj.write(fp) + + def dumpf(self, obj, path, encoding=None): + """Serialize obj to path in ARPA format (.arpa, .gz).""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="wt", encoding=encoding) as f: + return self.dump(obj, f) + else: + with open(path, mode="wt", encoding=encoding) as f: + self.dump(obj, f) + + def dumps(self, obj): + """Serialize obj to an ARPA formatted str.""" + with StringIO() as f: + self.dump(obj, f) + return f.getvalue() + + +def add_log_p(prev_log_sum, log_p, base): + return math.log(base**log_p + base**prev_log_sum, base) + + +def compute_numerator_denominator(lm, h): + log_sum_seen_h = -math.inf + log_sum_seen_h_lower = -math.inf + base = lm.base + for w, log_p in lm._ngrams[len(h)][h].items(): + log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) + + ngram = h + (w,) + log_p_lower = lm.log_p_raw(ngram[1:]) + log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) + + numerator = 1.0 - base**log_sum_seen_h + denominator = 1.0 - base**log_sum_seen_h_lower + return numerator, denominator + + +def prune(lm, threshold, minorder): + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 + + for i in range( + lm.order(), max(minorder - 1, 1), -1 + ): # i is the order of the ngram (h, w) + logging.info("processing %d-grams ..." % i) + count_pruned_ngrams = 0 + + h_dict = lm._ngrams[i - 1] + for h in list(h_dict.keys()): + # old backoff weight, BOW(h) + log_bow = lm._log_bo(h) + if log_bow is None: + log_bow = 0 + + # Compute numerator and denominator of the backoff weight, + # so that we can quickly compute the BOW adjustment due to + # leaving out one prob. 
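+            #
+            # Here numerator   = 1 - sum of P(w|h)  over the words w seen
+            # after h, and denominator = 1 - sum of P(w|h') over the same
+            # words, where h' is the lower-order (backed-off) context, so
+            # that BOW(h) = numerator / denominator.  For each candidate
+            # ngram hw the loop below computes the BOW that would result
+            # from dropping hw, the entropy change weighted by P(h), and
+            # prunes hw if the relative perplexity change is below the
+            # threshold.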
+ numerator, denominator = compute_numerator_denominator(lm, h) + + # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 + + # Compute the marginal probability of the context, P(h) + h_log_p = lm.log_joint_prob(h) + + all_pruned = True + pruned_w_set = set() + + for w, log_p in h_dict[h].items(): + ngram = h + (w,) + + # lower-order estimate for ngramProb, P(w|h') + backoff_prob = lm.log_p_raw(ngram[1:]) + + # Compute BOW after removing ngram, BOW'(h) + new_log_bow = math.log( + numerator + lm.base**log_p, lm.base + ) - math.log(denominator + lm.base**backoff_prob, lm.base) + + # Compute change in entropy due to removal of ngram + delta_prob = backoff_prob + new_log_bow - log_p + delta_entropy = -(lm.base**h_log_p) * ( + (lm.base**log_p) * delta_prob + + numerator * (new_log_bow - log_bow) + ) + + # compute relative change in model (training set) perplexity + perp_change = lm.base**delta_entropy - 1.0 + + pruned = threshold > 0 and perp_change < threshold + + # Make sure we don't prune ngrams whose backoff nodes are needed + if ( + pruned + and len(ngram) in lm._ngrams + and len(lm._ngrams[len(ngram)][ngram]) > 0 + ): + pruned = False + + logging.debug( + "CONTEXT " + + str(h) + + " WORD " + + w + + " CONTEXTPROB %f " % h_log_p + + " OLDPROB %f " % log_p + + " NEWPROB %f " % (backoff_prob + new_log_bow) + + " DELTA-H %f " % delta_entropy + + " DELTA-LOGP %f " % delta_prob + + " PPL-CHANGE %f " % perp_change + + " PRUNED " + + str(pruned) + ) + + if pruned: + pruned_w_set.add(w) + count_pruned_ngrams += 1 + else: + all_pruned = False + + # If we removed all ngrams for this context we can + # remove the context itself, but only if the present + # context is not a prefix to a longer one. + if all_pruned and len(pruned_w_set) == len(h_dict[h]): + del h_dict[ + h + ] # this context h is no longer needed, as its ngram prob is stored at its own context h' + elif len(pruned_w_set) > 0: + # The pruning for this context h is actually done here + old_context = lm.set_new_context(h) + + for w, p_w in old_context.items(): + if w not in pruned_w_set: + lm.add_entry( + h + (w,), p_w + ) # the entry hw is stored at the context h + + # We need to recompute the back-off weight, but + # this can only be done after completing the pruning + # of the lower-order ngrams. + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 + + logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) + + # recompute backoff weights + for i in range( + max(minorder - 1, 1) + 1, lm.order() + 1 + ): # be careful of this order: from low- to high-order + for h in lm._ngrams[i - 1]: + numerator, denominator = compute_numerator_denominator(lm, h) + new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) + lm._ngrams[len(h)][h].log_bo = new_log_bow + + # update counts + lm.update_counts() + + return + + +def check_h_is_valid(lm, h): + sum_under_h = sum( + [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] + ) + if abs(sum_under_h - 1.0) > 1e-6: + logging.info("warning: %s %f" % (str(h), sum_under_h)) + return False + else: + return True + + +def validate_lm(lm): + # sanity check if the conditional probability sums to one under each context h + for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) + logging.info("validating %d-grams ..." 
% i) + h_dict = lm._ngrams[i - 1] + for h in h_dict.keys(): + check_h_is_valid(lm, h) + + +def compare_two_apras(path1, path2): + pass + + +if __name__ == "__main__": + # load an arpa file + logging.info("Loading the arpa file from %s" % args.lm) + parser = ArpaParser() + models = parser.loadf(args.lm, encoding=default_encoding) + lm = models[0] # ARPA files may contain several models. + logging.info("Stats before pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + + # prune it, the language model will be modified in-place + logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) + prune(lm, args.threshold, args.minorder) + + # validate_lm(lm) + + # write the arpa language model to a file + logging.info("Stats after pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + logging.info("Saving the pruned arpa file to %s" % args.write_lm) + parser.dumpf(lm, args.write_lm, encoding=default_encoding) + logging.info("Done.") diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh deleted file mode 120000 index e4665e7de..000000000 --- a/icefall/shared/parse_options.sh +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh new file mode 100755 index 000000000..71fb9e5ea --- /dev/null +++ b/icefall/shared/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. From 59c943878ff7f3d741a29d743b8560e342fa892d Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Thu, 16 Nov 2023 07:38:31 +0100 Subject: [PATCH 047/216] add the `voxpopuli` recipe (#1374) * add the `voxpopuli` recipe - this is the data preparation - there is no ASR training and no results * update the PR#1374 (feedback from @csukuangfj) - fixing .py headers and docstrings - removing BUT specific parts of `prepare.sh` - adding assert `num_jobs >= num_workers` to `compute_fbank.py` - narrowing list of languages (let's limit to ASR sets with transcripts for now) - added links to `README.md` - extending `text_from_manifest.py` --- egs/voxpopuli/ASR/README.md | 38 +++ egs/voxpopuli/ASR/local/compute_fbank.py | 248 +++++++++++++++++ .../ASR/local/compute_fbank_musan.py | 1 + .../ASR/local/display_manifest_statistics.py | 56 ++++ .../duration_from_supervision_manifest.py | 93 +++++++ egs/voxpopuli/ASR/local/filter_cuts.py | 1 + egs/voxpopuli/ASR/local/prepare_lang_bpe.py | 1 + .../ASR/local/preprocess_voxpopuli.py | 178 ++++++++++++ .../ASR/local/separate_punctuation.py | 130 +++++++++ egs/voxpopuli/ASR/local/text_from_manifest.py | 54 ++++ egs/voxpopuli/ASR/local/train_bpe_model.py | 1 + .../ASR/local/uppercase_begin_of_sentence.py | 113 ++++++++ .../ASR/local/validate_bpe_lexicon.py | 1 + .../ASR/local/validate_cutset_manifest.py | 123 +++++++++ egs/voxpopuli/ASR/prepare.sh | 257 ++++++++++++++++++ egs/voxpopuli/ASR/shared | 1 + 16 files changed, 1296 insertions(+) create mode 100644 egs/voxpopuli/ASR/README.md create mode 100755 egs/voxpopuli/ASR/local/compute_fbank.py create mode 120000 egs/voxpopuli/ASR/local/compute_fbank_musan.py create mode 100755 egs/voxpopuli/ASR/local/display_manifest_statistics.py create mode 100755 egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py create mode 120000 egs/voxpopuli/ASR/local/filter_cuts.py create mode 
120000 egs/voxpopuli/ASR/local/prepare_lang_bpe.py create mode 100755 egs/voxpopuli/ASR/local/preprocess_voxpopuli.py create mode 100755 egs/voxpopuli/ASR/local/separate_punctuation.py create mode 100755 egs/voxpopuli/ASR/local/text_from_manifest.py create mode 120000 egs/voxpopuli/ASR/local/train_bpe_model.py create mode 100755 egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py create mode 120000 egs/voxpopuli/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/voxpopuli/ASR/local/validate_cutset_manifest.py create mode 100755 egs/voxpopuli/ASR/prepare.sh create mode 120000 egs/voxpopuli/ASR/shared diff --git a/egs/voxpopuli/ASR/README.md b/egs/voxpopuli/ASR/README.md new file mode 100644 index 000000000..92aa26464 --- /dev/null +++ b/egs/voxpopuli/ASR/README.md @@ -0,0 +1,38 @@ +# Readme + +This recipe contains data preparation for the +[VoxPopuli](https://github.com/facebookresearch/voxpopuli) dataset +[(pdf)](https://aclanthology.org/2021.acl-long.80.pdf). +At the moment, without model training. + + +## audio per language + +| language | Size | Hrs. untranscribed | Hrs. transcribed | +|----------|--------|--------------------|------------------| +| bg | 295G | 17.6K | - | +| cs | 308G | 18.7K | 62 | +| da | 233G | 13.6K | - | +| de | 379G | 23.2K | 282 | +| el | 305G | 17.7K | - | +| en | 382G | 24.1K | 543 | +| es | 362G | 21.4K | 166 | +| et | 179G | 10.6K | 3 | +| fi | 236G | 14.2K | 27 | +| fr | 376G | 22.8K | 211 | +| hr | 132G | 8.1K | 43 | +| hu | 297G | 17.7K | 63 | +| it | 361G | 21.9K | 91 | +| lt | 243G | 14.4K | 2 | +| lv | 217G | 13.1K | - | +| mt | 147G | 9.1K | - | +| nl | 322G | 19.0K | 53 | +| pl | 348G | 21.2K | 111 | +| pt | 300G | 17.5K | - | +| ro | 296G | 17.9K | 89 | +| sk | 201G | 12.1K | 35 | +| sl | 190G | 11.3K | 10 | +| sv | 272G | 16.3K | - | +| | | | | +| total | 6.3T | 384K | 1791 | + diff --git a/egs/voxpopuli/ASR/local/compute_fbank.py b/egs/voxpopuli/ASR/local/compute_fbank.py new file mode 100755 index 000000000..b63e51f29 --- /dev/null +++ b/egs/voxpopuli/ASR/local/compute_fbank.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of VoxPopuli dataset. + +Usage example: + + python3 ./local/compute_fbank.py \ + --src-dir data/fbank --output-dir data/fbank \ + --num-jobs 100 --num-workers 25 \ + --prefix "voxpopuli-${task}-${lang}" \ + --dataset train \ + --trim-to-supervisions True \ + --speed-perturb True + +It looks for raw CutSet in the directory data/fbank +located at: `{src_dir}/{prefix}_cuts_{dataset}_raw.jsonl.gz`. + +The generated fbank features are saved in `data/fbank/{prefix}-{dataset}_feats` +and CutSet manifest stored in `data/fbank/{prefix}_cuts_{dataset}.jsonl.gz`. 
+ +Typically, the number of workers is smaller than number of jobs +(see --num-jobs 100 --num-workers 25 in the example). +And, the number of jobs should be at least the number of workers (it's checked). +""" + +import argparse +import logging +import multiprocessing +import os +from concurrent.futures import ProcessPoolExecutor +from pathlib import Path + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + is_caching_enabled, + set_caching_enabled, +) + +from icefall.utils import str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + parser.add_argument( + "--src-dir", + type=str, + help="""Folder with the input manifest files.""", + default="data/manifests", + ) + parser.add_argument( + "--output-dir", + type=str, + help="""Folder with the output manifests (cuts) and feature files.""", + default="data/fbank", + ) + + parser.add_argument( + "--prefix", + type=str, + help="""Prefix of the manifest files.""", + default="", + ) + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank (train,test,dev).""", + default=None, + ) + + parser.add_argument( + "--num-jobs", + type=int, + help="""Number of jobs (i.e. files with extracted features)""", + default=50, + ) + parser.add_argument( + "--num-workers", + type=int, + help="""Number of parallel workers""", + default=10, + ) + parser.add_argument( + "--speed-perturb", + type=str2bool, + default=False, + help="""Enable speed perturbation for the set.""", + ) + parser.add_argument( + "--trim-to-supervisions", + type=str2bool, + default=False, + help="""Apply `trim-to-supervision` to cut set.""", + ) + + return parser.parse_args() + + +def compute_fbank_features(args: argparse.Namespace): + set_caching_enabled(True) # lhotse + + src_dir = Path(args.src_dir) + output_dir = Path(args.output_dir) + num_jobs = args.num_jobs + num_workers = min(args.num_workers, os.cpu_count()) + num_mel_bins = 80 + + bpe_model = args.bpe_model + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + prefix = args.prefix # "ELEF_TRAIN" + dataset = args.dataset + suffix = "jsonl.gz" + + cuts_raw_filename = Path(f"{src_dir}/{prefix}_cuts_{dataset}_raw.{suffix}") + cuts_raw = CutSet.from_file(cuts_raw_filename) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + cuts_filename = Path(f"{prefix}_cuts_{dataset}.{suffix}") + if (output_dir / cuts_filename).is_file(): + logging.info(f"{output_dir/cuts_filename} already exists - skipping.") + return + + logging.info(f"Processing {output_dir/cuts_filename}") + cut_set = cuts_raw + + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + + if args.speed_perturb: + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + + if args.trim_to_supervisions: + logging.info(f"About to `trim_to_supervisions()` {output_dir / cuts_filename}") + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + else: + 
logging.info( + "Not doing `trim_to_supervisions()`, " + "to enable use --trim-to-supervision=True" + ) + + cut_set = cut_set.to_eager() # disallow lazy evaluation (sorting requires it) + cut_set = cut_set.sort_by_recording_id() # enhances AudioCache hit rate + + # We typically use `num_jobs=100, num_workers=20` + # - this is helpful for large databases + # - both values are configurable externally + assert num_jobs >= num_workers, (num_jobs, num_workers) + executor = ProcessPoolExecutor( + max_workers=num_workers, + mp_context=multiprocessing.get_context("spawn"), + initializer=set_caching_enabled, + initargs=(is_caching_enabled(),), + ) + + logging.info( + f"executor {executor} : num_workers {num_workers}, num_jobs {num_jobs}" + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir / prefix}-{dataset}_feats", + num_jobs=num_jobs, + executor=executor, + storage_type=LilcomChunkyWriter, + ) + + # correct small deviations of duration, caused by speed-perturbation + for cut in cut_set: + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id) + duration_difference = abs(cut.supervisions[0].duration - cut.duration) + tolerance = 0.02 # 20ms + if duration_difference == 0.0: + pass + elif duration_difference <= tolerance: + logging.info( + "small mismatch of the supervision duration " + f"(Δt = {duration_difference*1000}ms), " + f"correcting : cut.duration {cut.duration} -> " + f"supervision {cut.supervisions[0].duration}" + ) + cut.supervisions[0].duration = cut.duration + else: + logging.error( + "mismatch of cut/supervision duration " + f"(Δt = {duration_difference*1000}ms) : " + f"cut.duration {cut.duration}, " + f"supervision {cut.supervisions[0].duration}" + ) + raise ValueError( + "mismatch of cut/supervision duration " + f"(Δt = {duration_difference*1000}ms)" + ) + + # store the cutset + logging.info(f"storing CutSet to : `{output_dir / cuts_filename}`") + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + logging.info(vars(args)) + + compute_fbank_features(args) diff --git a/egs/voxpopuli/ASR/local/compute_fbank_musan.py b/egs/voxpopuli/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/voxpopuli/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/display_manifest_statistics.py b/egs/voxpopuli/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..36c99e126 --- /dev/null +++ b/egs/voxpopuli/ASR/local/display_manifest_statistics.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +Usage example: + python3 ./local/display_manifest_statistics.py data/fbank/*_cuts*.jsonl.gz + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. + +""" + +import argparse + +from lhotse import load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser("Compute statistics for 'cuts' .jsonl.gz") + + parser.add_argument( + "filename", + help="data/fbank/imported_cuts_bison-train_trim.jsonl.gz", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + cuts = load_manifest_lazy(args.filename) + cuts.describe() + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py new file mode 100755 index 000000000..957267fe8 --- /dev/null +++ b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script computes durations of datasets from +the SupervisionSet manifests. + +Usage example: + + python3 ./local/duration_from_supervision_manifest.py \ + data/manifest/*_superivions*.jsonl.gz +""" + +import argparse +import gzip +import json +import logging +import re +import sys + + +def get_args(): + parser = argparse.ArgumentParser( + "Read the raw text from the 'supervisions.jsonl.gz'" + ) + + parser.add_argument( + "filename", + help="supervisions.jsonl.gz", + nargs="+", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.info(vars(args)) + + total_duration = 0.0 + total_n_utts = 0 + + for fname in args.filename: + if fname == "-": + fd = sys.stdin + elif re.match(r".*\.jsonl\.gz$", fname): + fd = gzip.open(fname, mode="r") + else: + fd = open(fname, mode="r") + + fname_duration = 0.0 + n_utts = 0 + for line in fd: + js = json.loads(line) + fname_duration += js["duration"] + n_utts += 1 + + print( + f"Duration: {fname_duration/3600:7.2f} hours " + f"(eq. {fname_duration:7.0f} seconds, {n_utts} utts): {fname}" + ) + + if fd != sys.stdin: + fd.close() + + total_duration += fname_duration + total_n_utts += n_utts + + print( + f"Total duration: {total_duration/3600:7.2f} hours " + f"(eq. 
{total_duration:7.0f} seconds)" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/voxpopuli/ASR/local/filter_cuts.py b/egs/voxpopuli/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/voxpopuli/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/prepare_lang_bpe.py b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py new file mode 100755 index 000000000..4032537db --- /dev/null +++ b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# 2023 Brno University of Technology (author: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preprocess the database. +- Convert RecordingSet and SupervisionSet to CutSet. +- Apply text normalization to the transcripts. + - We take renormalized `orig_text` as `text` transcripts. + - The text normalization is separating punctuation from words. + - Also we put capital letter to the beginning of a sentence. + +The script is inspired in: + `egs/commonvoice/ASR/local/preprocess_commonvoice.py` + +Usage example: + python3 ./local/preprocess_voxpopuli.py \ + --task asr --lang en + +""" + +import argparse +import logging +from pathlib import Path +from typing import Optional + +from lhotse import CutSet +from lhotse.recipes.utils import read_manifests_if_cached + +# from local/ +from separate_punctuation import separate_punctuation +from uppercase_begin_of_sentence import UpperCaseBeginOfSentence + +from icefall.utils import str2bool + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + default=None, + ) + + parser.add_argument( + "--task", + type=str, + help="""Task of VoxPopuli""", + default="asr", + ) + + parser.add_argument( + "--lang", + type=str, + help="""Language of VoxPopuli""", + required=True, + ) + + parser.add_argument( + "--use-original-text", + type=str2bool, + help="""Use 'original_text' from the annoattaion file, + otherwise 'normed_text' will be used + (see `data/manifests/${task}_${lang}.tsv.gz`). 
+ """, + default=False, + ) + + return parser.parse_args() + + +def normalize_text(utt: str) -> str: + utt = UpperCaseBeginOfSentence().process_line_text(separate_punctuation(utt)) + return utt + + +def preprocess_voxpopuli( + task: str, + language: str, + dataset: Optional[str] = None, + use_original_text: bool = False, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + output_dir.mkdir(exist_ok=True) + + if dataset is None: + dataset_parts = ( + "dev", + "test", + "train", + ) + else: + dataset_parts = dataset.split(" ", -1) + + logging.info("Loading manifest") + prefix = f"voxpopuli-{task}-{language}" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, + prefix=prefix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + if use_original_text: + logging.info("Using 'original_text' from the annotation file.") + logging.info(f"Normalizing text in {partition}") + for sup in m["supervisions"]: + # `orig_text` includes punctuation and true-case + orig_text = str(sup.custom["orig_text"]) + # we replace `text` by normalized `orig_text` + sup.text = normalize_text(orig_text) + else: + logging.info("Using 'normed_text' from the annotation file.") + + # remove supervisions with empty 'text' + m["supervisions"] = m["supervisions"].filter(lambda sup: len(sup.text) > 0) + + # Create cut manifest with long-recordings. + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ).resample(16000) + + # Store the cut set incl. the resampling. + logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + preprocess_voxpopuli( + task=args.task, + language=args.lang, + dataset=args.dataset, + use_original_text=args.use_original_text, + ) + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/separate_punctuation.py b/egs/voxpopuli/ASR/local/separate_punctuation.py new file mode 100755 index 000000000..706d6fcd5 --- /dev/null +++ b/egs/voxpopuli/ASR/local/separate_punctuation.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script chops the punctuation as standalone tokens. +Example: + input: "This is fine. Yes, you are right." 
+ output: "This is fine . Yes , you are right ." + +The script also handles exceptions in a hard-coded fashion. + +(same functionality could be done with `nltk.tokenize.word_tokenize()`, + but that would be an extra dependency) + +It can be used as a module, or as an executable script. + +Usage example #1: + `from separate_punctuation import separate_punctuation` + +Usage example #2: +``` + python3 ./local/separate_punctuation.py \ + --ignore-columns 1 \ + < ${kaldi_data}/text +``` +""" + +import re +import sys +from argparse import ArgumentParser + + +def separate_punctuation(text: str) -> str: + """ + Text filtering function for separating punctuation. + + Example: + input: "This is fine. Yes, you are right." + output: "This is fine . Yes , you are right ." + + The exceptions for which the punctuation is + not splitted are hard-coded. + """ + + # remove non-desired punctuation symbols + text = re.sub('["„“«»]', "", text) + + # separate [,.!?;] punctuation from words by space + text = re.sub(r"(\w)([,.!?;])", r"\1 \2", text) + text = re.sub(r"([,.!?;])(\w)", r"\1 \2", text) + + # split to tokens + tokens = text.split() + tokens_out = [] + + # re-join the special cases of punctuation + for ii, tok in enumerate(tokens): + # no rewriting for 1st and last token + if ii > 0 and ii < len(tokens) - 1: + # **RULES ADDED FOR CZECH COMMON VOICE** + + # fix "27 . dubna" -> "27. dubna", but keep punctuation separate, + if tok == "." and tokens[ii - 1].isdigit() and tokens[ii + 1].islower(): + tokens_out[-1] = tokens_out[-1] + "." + continue + + # fix "resp . pak" -> "resp. pak" + if tok == "." and tokens[ii - 1].isalpha() and tokens[ii + 1].islower(): + tokens_out[-1] = tokens_out[-1] + "." + continue + + # **RULES ADDED FOR ENGLISH COMMON VOICE** + + # fix "A ." -> "A." + if tok == "." and re.match(r"^[A-Z]S", tokens[ii - 1]): + tokens_out[-1] = tokens_out[-1] + "." + continue + + # fix "Mr ." -> "Mr." + exceptions = set(["Mr", "Mrs", "Ms"]) + if tok == "." and tokens[ii - 1] in exceptions: + tokens_out[-1] = tokens_out[-1] + "." + continue + + tokens_out.append(tok) + + return " ".join(tokens_out) + + +def get_args(): + parser = ArgumentParser( + description="Separate punctuation from words: 'hello.' -> 'hello .'" + ) + parser.add_argument( + "--ignore-columns", type=int, default=1, help="skip number of initial columns" + ) + return parser.parse_args() + + +def main(): + args = get_args() + + max_split = args.ignore_columns + + while True: + line = sys.stdin.readline() + if not line: + break + + *key, text = line.strip().split(maxsplit=max_split) + text_norm = separate_punctuation(text) + + print(" ".join(key), text_norm) + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/text_from_manifest.py b/egs/voxpopuli/ASR/local/text_from_manifest.py new file mode 100755 index 000000000..d9ab53b5a --- /dev/null +++ b/egs/voxpopuli/ASR/local/text_from_manifest.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Print the text contained in `supervisions.jsonl.gz` or `cuts.jsonl.gz`. + +Usage example: + python3 ./local/text_from_manifest.py \ + data/manifests/voxpopuli-asr-en_supervisions_dev.jsonl.gz +""" + +import argparse +import gzip +import json + + +def get_args(): + parser = argparse.ArgumentParser( + "Read the raw text from the 'supervisions.jsonl.gz'" + ) + parser.add_argument("filename", help="supervisions.jsonl.gz") + return parser.parse_args() + + +def main(): + args = get_args() + + with gzip.open(args.filename, mode="r") as fd: + for line in fd: + js = json.loads(line) + if "text" in js: + print(js["text"]) # supervisions.jsonl.gz + elif "supervisions" in js: + for s in js["supervisions"]: + print(s["text"]) # cuts.jsonl.gz + else: + raise Exception(f"Unknown jsonl format of {args.filename}") + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/train_bpe_model.py b/egs/voxpopuli/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/voxpopuli/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py new file mode 100755 index 000000000..8e9de905f --- /dev/null +++ b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script introduces initial capital letter at the beginning of a sentence. +It can be used as a module, or as an executable script. + +Usage example #1: + `from uppercase_begin_of_sentence import UpperCaseBeginOfSentence` + +Usage example #2: +``` + python3 ./local/uppercase_begin_of_sentence.py \ + --ignore-columns 1 \ + < ${kaldi_data}/text +``` +""" + +import re +import sys +from argparse import ArgumentParser + + +class UpperCaseBeginOfSentence: + """ + This class introduces initial capital letter at the beginning of a sentence. + Capital letter is used, if previous symbol was punctuation token from + `set([".", "!", "?"])`. + + The punctuation as previous token is memorized also across + `process_line_text()` calls. 
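+
+    Example:
+        input:  "this is fine . yes , you are right ."
+        output: "This is fine . Yes , you are right ."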
+ """ + + def __init__(self): + # The 1st word will have Title-case + # This variable transfers context from previous line + self.prev_token_is_punct = True + + def process_line_text(self, line_text: str) -> str: + """ + It is assumed that punctuation in `line_text` was already separated, + example: "This is fine . Yes , you are right ." + """ + + words = line_text.split() + punct_set = set([".", "!", "?"]) + + for ii, w in enumerate(words): + # punctuation ? + if w in punct_set: + self.prev_token_is_punct = True + continue + + # change case of word... + if self.prev_token_is_punct: + if re.match("<", w): + continue # skip + # apply Title-case only on lowercase words. + if w.islower(): + words[ii] = w.title() + # change state + self.prev_token_is_punct = False + + line_text_uc = " ".join(words) + + return line_text_uc + + +def get_args(): + parser = ArgumentParser( + description="Put upper-case at the beginning of a sentence." + ) + parser.add_argument( + "--ignore-columns", type=int, default=4, help="skip number of initial columns" + ) + return parser.parse_args() + + +def main(): + args = get_args() + + uc_bos = UpperCaseBeginOfSentence() + max_split = args.ignore_columns + + while True: + line = sys.stdin.readline() + if not line: + break + line = line.strip() + + if len(line.split()) > 1: + *key, text = line.strip().split(maxsplit=max_split) # parse, + text_uc = uc_bos.process_line_text(text) # process, + print(" ".join(key), text_uc) # print, + else: + print(line) + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/validate_cutset_manifest.py b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py new file mode 100755 index 000000000..4659aa9cd --- /dev/null +++ b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut +- Supervision time bounds are within Cut time bounds +- Duration of Cut and Superivion are equal + +We will add more checks later if needed. 
+ +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz + +(Based on: `librispeech/ASR/local/validate_manifest.py`) +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut +from lhotse.dataset.speech_recognition import validate_for_asr + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "cutset_manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def validate_one_supervision_per_cut(c: Cut): + if len(c.supervisions) != 1: + raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") + + +def validate_supervision_and_cut_time_bounds(c: Cut): + tol = 2e-3 # same tolerance as in 'validate_for_asr()' + s = c.supervisions[0] + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + if s.start < -tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} must not be negative." + ) + if s.start > tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} " + "is not at the beginning of the Cut. " + "Please apply `lhotse cut trim-to-supervisions`." + ) + if c.start + s.end > c.end + tol: + raise ValueError( + f"{c.id}: Supervision end time {c.start+s.end} is larger " + f"than cut end time {c.end}" + ) + + if s.duration != c.duration: + raise ValueError( + f"{c.id}: Cut duration {c.duration} and supervision duration " + f"{s.duration} must be the same.\n" + f"The difference causes problems in the training code : " + f"+/- 1 frame in `x`, `x_lens` in `Zipformer::forward()`.\n" + f"Did you forget to apply `trim_to_supervisions()` ?" + ) + + +def main(): + args = get_args() + + manifest = args.cutset_manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet) + + try: + for c in cut_set: + validate_one_supervision_per_cut(c) + validate_supervision_and_cut_time_bounds(c) + + # Validation from K2 training + # - checks supervision start is 0 + # - checks supervision.duration is not longer than cut.duration + # - there is tolerance 2ms + validate_for_asr(cut_set) + except BaseException as e: + logging.error(str(e)) + raise + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/voxpopuli/ASR/prepare.sh b/egs/voxpopuli/ASR/prepare.sh new file mode 100755 index 000000000..7cddad756 --- /dev/null +++ b/egs/voxpopuli/ASR/prepare.sh @@ -0,0 +1,257 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -euxo pipefail + +nj=20 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. 
+# +# - $dl_dir/voxpopuli/raw_audios/$lang/$year +# This directory contains *.ogg files with audio downloaded and extracted from archives: +# https://dl.fbaipublicfiles.com/voxpopuli/audios/${lang}_${year}.tar +# +# - Note: the voxpopuli transcripts are downloaded to a ${tmp} folder +# as part of `lhotse prepare voxpopuli` from: +# https://dl.fbaipublicfiles.com/voxpopuli/annotations/asr/asr_${lang}.tsv.gz +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download +#dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA # BUT + +musan_dir=${dl_dir}/musan +#musan_dir=/mnt/matylda2/data/MUSAN # BUT + +# Choose value from ASR_LANGUAGES: +# +# [ "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr", +# "sk", "sl", "et", "lt" ] +# +# See ASR_LANGUAGES in: +# https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/recipes/voxpopuli.py#L54C4-L54C4 +lang=en + +task=asr + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/${lang}/lang_bpe_xxx, +# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data/${lang}". +# You can safely remove "data/${lang}" and rerun this script to regenerate it. +mkdir -p data/${lang} + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" +log "musan_dir: $musan_dir" +log "task: $task, lang: $lang" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/$release, + # you can create a symlink + # + # ln -sfv /path/to/$release $dl_dir/$release + # + if [ ! -d $dl_dir/voxpopuli/raw_audios/${lang} ]; then + lhotse download voxpopuli --subset $lang $dl_dir/voxpopuli + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $musan_dir/musan ]; then + lhotse download musan $musan_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare VoxPopuli manifest" + # We assume that you have downloaded the VoxPopuli corpus + # to $dl_dir/voxpopuli + if [ ! -e data/manifests/.voxpopuli-${task}-${lang}.done ]; then + # Warning : it requires Internet connection (it downloads transcripts to ${tmpdir}) + lhotse prepare voxpopuli --task asr --lang $lang -j $nj $dl_dir/voxpopuli data/manifests + touch data/manifests/.voxpopuli-${task}-${lang}.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + #lhotse prepare musan $dl_dir/musan data/manifests + lhotse prepare musan $musan_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Preprocess VoxPopuli manifest" + mkdir -p data/fbank + if [ ! 
-e data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete ]; then + # recordings + supervisions -> cutset + ./local/preprocess_voxpopuli.py --task $task --lang $lang \ + --use-original-text True + touch data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for dev and test subsets of VoxPopuli" + mkdir -p data/fbank + for dataset in "dev" "test"; do + if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done ]; then + ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \ + --num-jobs 50 --num-workers ${nj} \ + --prefix "voxpopuli-${task}-${lang}" \ + --dataset ${dataset} \ + --trim-to-supervisions True + touch data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done + fi + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute fbank for train set of VoxPopuli" + if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-train.done ]; then + ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \ + --num-jobs 100 --num-workers ${nj} \ + --prefix "voxpopuli-${task}-${lang}" \ + --dataset train \ + --trim-to-supervisions True \ + --speed-perturb True + touch data/fbank/.voxpopuli-${task}-${lang}-train.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Validate fbank manifests for VoxPopuli" + for dataset in "dev" "test" "train"; do + mkdir -p data/fbank/log/ + ./local/validate_cutset_manifest.py \ + data/fbank/voxpopuli-asr-en_cuts_${dataset}.jsonl.gz \ + 2>&1 | tee data/fbank/log/validate_voxpopuli-asr-en_cuts_${dataset}.log + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size}_${lang} + mkdir -p $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + file=$( + find "data/fbank/voxpopuli-${task}-${lang}_cuts_train.jsonl.gz" + ) + local/text_from_manifest.py $file >$lang_dir/transcript_words.txt + # gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt + + # Ensure space only appears once + #sed -i 's/\t/ /g' $lang_dir/transcript_words.txt + #sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/words.txt ]; then + cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' > $lang_dir/words.txt + (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) | + cat - $lang_dir/words.txt | sort | uniq | awk ' + BEGIN { + print "<eps> 0"; + } + { + if ($1 == "<s>") { + print "<s> is in the vocabulary!" | "cat 1>&2" + exit 1; + } + if ($1 == "</s>") { + print "</s> is in the vocabulary!" | "cat 1>&2" + exit 1; + } + printf("%s %d\n", $1, NR); + } + END { + printf("#0 %d\n", NR+1); + printf("<s> %d\n", NR+2); + printf("</s> %d\n", NR+3); + }' > $lang_dir/words || exit 1; + mv $lang_dir/words $lang_dir/words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ !
-f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi diff --git a/egs/voxpopuli/ASR/shared b/egs/voxpopuli/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/voxpopuli/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file From 666d69b20d53796420593d99b0c0d6e9cd2212cc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 17 Nov 2023 18:12:59 +0800 Subject: [PATCH 048/216] Rename train2.py to avoid confusion (#1386) --- .github/scripts/run-multi-zh_hans-zipformer.sh | 4 +++- egs/aishell/ASR/prepare.sh | 5 ++--- .../{train2.py => do_not_use_it_directly.py} | 1 + egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py | 2 +- .../{train2.py => do_not_use_it_directly.py} | 1 + .../ASR/pruned_transducer_stateless7_streaming/README.md | 4 ++-- .../{train2.py => do_not_use_it_directly.py} | 1 + .../{train2.py => do_not_use_it_directly.py} | 1 + .../export-for-ncnn.py | 2 +- .../{train2.py => do_not_use_it_directly.py} | 1 + .../conv_emformer_transducer_stateless2/export-for-ncnn.py | 2 +- .../ASR/conv_emformer_transducer_stateless2/export-onnx.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/README.md | 4 ++-- .../{train2.py => do_not_use_it_directly.py} | 1 + .../export-for-ncnn-zh.py | 2 +- .../export-for-ncnn.py | 2 +- .../do_not_use_it_directly.py | 1 + .../export-for-ncnn.py | 2 +- .../pruned_transducer_stateless7_streaming_multi/train2.py | 1 - 19 files changed, 23 insertions(+), 16 deletions(-) rename egs/aishell/ASR/pruned_transducer_stateless7/{train2.py => do_not_use_it_directly.py} (99%) rename egs/aishell/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%) rename egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%) rename egs/csj/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%) rename egs/librispeech/ASR/conv_emformer_transducer_stateless2/{train2.py => do_not_use_it_directly.py} (99%) rename egs/librispeech/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py delete mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-zh_hans-zipformer.sh index dd32a94f8..cbd86a4d3 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-zh_hans-zipformer.sh @@ -51,6 +51,8 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000002.wav done +rm -rf $repo + log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ====" repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ @@ -92,4 +94,4 @@ for method in modified_beam_search fast_beam_search; do 
$repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav -done \ No newline at end of file +done diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index d36dc5ed3..9f73a2073 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -261,10 +261,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then fi if [ ! -f $lang_char_dir/HLG.fst ]; then - lang_phone_dir=data/lang_phone ./local/prepare_lang_fst.py \ - --lang-dir $lang_phone_dir \ - --ngram-G ./data/lm/G_3_gram.fst.txt + --lang-dir $lang_char_dir \ + --ngram-G ./data/lm/G_3_gram_char.fst.txt fi fi diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py similarity index 99% rename from egs/aishell/ASR/pruned_transducer_stateless7/train2.py rename to egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py index 057af297f..6027273b2 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py @@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() AsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py index 2a9fc57d5..39d988cd0 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py @@ -56,7 +56,7 @@ import torch.nn as nn from decoder2 import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from zipformer import Zipformer from icefall.checkpoint import ( diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py similarity index 99% rename from egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py rename to egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 88eb34104..3c13c19c6 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1233,6 +1233,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() AishellAsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md index 991875aaa..6c20bab2c 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md @@ -4,6 +4,6 @@ See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer [./emformer.py](./emformer.py) and [./train.py](./train.py) are basically the same as -[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). 
-The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py) is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py similarity index 99% rename from egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py rename to egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index c09c9537c..61a3f27db 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1237,6 +1237,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() CommonVoiceAsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py similarity index 99% rename from egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py rename to egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 4c866ddd8..acde72d80 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1274,6 +1274,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() CSJAsrDataModule.add_arguments(parser) Tokenizer.add_arguments(parser) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index ebdb596a5..b210430c6 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -72,7 +72,7 @@ from pathlib import Path import torch from scaling_converter import convert_scaled_to_non_scaled from tokenizer import Tokenizer -from train2 import add_model_arguments, get_params, get_transducer_model +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py similarity index 99% rename from egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py rename to egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py index 420dc1065..d614f0914 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py @@ -1099,6 +1099,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 85dbd4661..953f95c45 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -39,8 +39,8 @@ from pathlib import Path import k2 import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index ab046557f..1e59e0858 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -61,7 +61,7 @@ import torch.nn as nn from decoder import Decoder from emformer import Emformer from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md index d3691e647..0f3c63e75 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md @@ -4,7 +4,7 @@ See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer [./emformer.py](./emformer.py) and [./train.py](./train.py) are basically the same as -[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). -The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py) is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). 
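The renamed modules are still imported by the export scripts for their helpers (`add_model_arguments`, `get_params`, `get_transducer_model`); the `raise RuntimeError(...)` added at the top of `main()` only blocks running them directly as a training entry point. A minimal sketch of the same pattern, with made-up module and function names (not taken from the recipes):

```python
# guarded_module.py
def get_params():
    # Helper symbols like this one remain importable.
    return {"feature_dim": 80}


def main():
    # Executing the module as a program is what gets blocked.
    raise RuntimeError("Please don't use this file directly!")


if __name__ == "__main__":
    main()
```

```python
# export_script.py
from guarded_module import get_params  # importing never calls main()

print(get_params())  # -> {'feature_dim': 80}
```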
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py similarity index 99% rename from egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py rename to egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index aa6c0668a..cd26db6f3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py index 07de57a86..a7d06a5dd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -68,8 +68,8 @@ from pathlib import Path import k2 import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index 9a6b31268..8f2178b1d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -66,8 +66,8 @@ from pathlib import Path import k2 import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py new file mode 120000 index 000000000..beeffaa03 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/do_not_use_it_directly.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py index 9a6b31268..8f2178b1d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -66,8 +66,8 @@ from pathlib import Path import k2 import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py 
b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py deleted file mode 120000 index 3c3280b68..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7_streaming/train2.py \ No newline at end of file From 11d816d174076ec9485ab8b1d36af2592514e348 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sat, 18 Nov 2023 18:47:55 +0800 Subject: [PATCH 049/216] Add cumstomized score for hotwords (#1385) * add custom score for each hotword * Add more comments * Fix deocde * fix style * minor fixes --- .../pruned_transducer_stateless7/decode.py | 2 +- .../decode.py | 2 +- .../pruned_transducer_stateless4/decode.py | 4 +- egs/librispeech/ASR/zipformer/decode.py | 4 +- .../pruned_transducer_stateless5/decode.py | 2 +- icefall/context_graph.py | 117 +++++++++++++----- 6 files changed, 92 insertions(+), 39 deletions(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py index be58c4e43..696eea906 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py @@ -641,7 +641,7 @@ def main(): contexts_text.append(line.strip()) contexts = graph_compiler.texts_to_ids(contexts_text) context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) + context_graph.build([(c, 0.0) for c in contexts]) else: context_graph = None else: diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py index f5ae836fd..99110d6b6 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -686,7 +686,7 @@ def main(): contexts_text.append(line.strip()) contexts = graph_compiler.texts_to_ids(contexts_text) context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) + context_graph.build([(c, 0.0) for c in contexts]) else: context_graph = None else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 524366068..5195a4ef6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -927,9 +927,9 @@ def main(): if os.path.exists(params.context_file): contexts = [] for line in open(params.context_file).readlines(): - contexts.append(line.strip()) + contexts.append((sp.encode(line.strip()), 0.0)) context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) + context_graph.build(contexts) else: context_graph = None else: diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 3531d657f..339e253e6 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -1001,9 +1001,9 @@ def main(): if os.path.exists(params.context_file): contexts = [] for line in open(params.context_file).readlines(): - contexts.append(line.strip()) + contexts.append((sp.encode(line.strip()), 0.0)) context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) + context_graph.build(contexts) else: context_graph = None else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 36b8a4b67..d665f3364 100755 
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -868,7 +868,7 @@ def main(): contexts_text.append(line.strip()) contexts = graph_compiler.texts_to_ids(contexts_text) context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) + context_graph.build([(c, 0.0) for c in contexts]) else: context_graph = None else: diff --git a/icefall/context_graph.py b/icefall/context_graph.py index 0b7c42c0b..b3d7972a8 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -84,6 +84,9 @@ class ContextGraph: context_score: The bonus score for each token(note: NOT for each word/phrase, it means longer word/phrase will have larger bonus score, they have to be matched though). + Note: This is just the default score for each token, the users can manually + specify the context_score for each word/phrase (i.e. different phrase might + have different token score). """ self.context_score = context_score self.num_nodes = 0 @@ -133,7 +136,7 @@ class ContextGraph: node.output_score += 0 if output is None else output.output_score queue.append(node) - def build(self, token_ids: List[List[int]]): + def build(self, token_ids: List[Tuple[List[int], float]]): """Build the ContextGraph from a list of token list. It first build a trie from the given token lists, then fill the fail arc for each trie node. @@ -142,26 +145,46 @@ class ContextGraph: Args: token_ids: - The given token lists to build the ContextGraph, it is a list of token list, - each token list contains the token ids for a word/phrase. The token id - could be an id of a char (modeling with single Chinese char) or an id - of a BPE (modeling with BPEs). + The given token lists to build the ContextGraph, it is a list of tuple of + token list and its customized score, the token list contains the token ids + for a word/phrase. The token id could be an id of a char + (modeling with single Chinese char) or an id of a BPE + (modeling with BPEs). The score is the total score for current token list, + 0 means using the default value (i.e. self.context_score). + + Note: The phrases would have shared states, the score of the shared states is + the maximum value among all the tokens sharing this state. """ - for tokens in token_ids: + for (tokens, score) in token_ids: node = self.root + # If has customized score using the customized token score, otherwise + # using the default score + context_score = ( + self.context_score if score == 0.0 else round(score / len(tokens), 2) + ) for i, token in enumerate(tokens): + node_next = {} if token not in node.next: self.num_nodes += 1 + node_id = self.num_nodes + token_score = context_score is_end = i == len(tokens) - 1 - node_score = node.node_score + self.context_score - node.next[token] = ContextState( - id=self.num_nodes, - token=token, - token_score=self.context_score, - node_score=node_score, - output_score=node_score if is_end else 0, - is_end=is_end, - ) + else: + # node exists, get the score of shared state. 
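# Worked example (the numbers match the tests further below): with contexts
# [("S", 5.0), ("SHE", 5.0)] and context_score=1.0, "S" gets a per-token score
# of 5.0 while each token of "SHE" gets round(5.0 / 3, 2) = 1.67.  The first
# state is shared between the two phrases, so the max() below keeps 5.0 for it,
# and a full match of "SHE" scores 5.0 + 1.67 + 1.67 = 8.34.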
+ token_score = max(context_score, node.next[token].token_score) + node_id = node.next[token].id + node_next = node.next[token].next + is_end = i == len(tokens) - 1 or node.next[token].is_end + node_score = node.node_score + token_score + node.next[token] = ContextState( + id=node_id, + token=token, + token_score=token_score, + node_score=node_score, + output_score=node_score if is_end else 0, + is_end=is_end, + ) + node.next[token].next = node_next node = node.next[token] self._fill_fail_output() @@ -343,7 +366,7 @@ class ContextGraph: return dot -if __name__ == "__main__": +def _test(queries, score): contexts_str = [ "S", "HE", @@ -355,9 +378,11 @@ if __name__ == "__main__": "THIS", "THEM", ] + + # test default score (1) contexts = [] for s in contexts_str: - contexts.append([ord(x) for x in s]) + contexts.append(([ord(x) for x in s], score)) context_graph = ContextGraph(context_score=1) context_graph.build(contexts) @@ -369,10 +394,28 @@ if __name__ == "__main__": context_graph.draw( title="Graph for: " + " / ".join(contexts_str), - filename="context_graph.pdf", + filename=f"context_graph_{score}.pdf", symbol_table=symbol_table, ) + for query, expected_score in queries.items(): + total_scores = 0 + state = context_graph.root + for q in query: + score, state = context_graph.forward_one_step(state, ord(q)) + total_scores += score + score, state = context_graph.finalize(state) + assert state.token == -1, state.token + total_scores += score + assert round(total_scores, 2) == expected_score, ( + total_scores, + expected_score, + query, + ) + + +if __name__ == "__main__": + # test default score queries = { "HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE" "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE" @@ -384,17 +427,27 @@ if __name__ == "__main__": "DHRHISQ": 4, # "HIS", "S" "THEN": 2, # "HE" } - for query, expected_score in queries.items(): - total_scores = 0 - state = context_graph.root - for q in query: - score, state = context_graph.forward_one_step(state, ord(q)) - total_scores += score - score, state = context_graph.finalize(state) - assert state.token == -1, state.token - total_scores += score - assert total_scores == expected_score, ( - total_scores, - expected_score, - query, - ) + _test(queries, 0) + + # test custom score (5) + # S : 5 + # HE : 5 (2.5 + 2.5) + # SHE : 8.34 (5 + 1.67 + 1.67) + # SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1) + # HIS : 5.84 (2.5 + 1.67 + 1.67) + # HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25) + # HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1) + # THIS : 5 (1.25 + 1.25 + 1.25 + 1.25) + queries = { + "HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE" + "HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE" + "HISHE": 24.18, # "HIS", "S", "SHE", "HE" + "SHED": 18.34, # "S", "SHE", "HE" + "SHELF": 18.34, # "S", "SHE", "HE" + "HELL": 5, # "HE" + "HELLO": 13, # "HE", "HELLO" + "DHRHISQ": 10.84, # "HIS", "S" + "THEN": 5, # "HE" + } + + _test(queries, 5) From 238b45bea85deee1a07cfd0f55b485cc92f67135 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 23 Nov 2023 01:22:57 +0800 Subject: [PATCH 050/216] Libriheavy recipe (zipformer) (#1261) * initial commit for libriheavy * Data prepare pipeline * Fix train.py * Fix decode.py * Add results * minor fixes * black * black * Incorporate PR https://github.com/k2-fsa/icefall/pull/1269 --------- Co-authored-by: zr_jin --- egs/libriheavy/ASR/README.md | 6 + egs/libriheavy/ASR/RESULTS.md | 114 +- .../ASR/local/compute_fbank_libriheavy.py | 242 +++ .../ASR/local/compute_fbank_musan.py | 1 + egs/libriheavy/ASR/local/norm_text.py | 58 + 
egs/libriheavy/ASR/local/prepare_manifest.py | 47 + egs/libriheavy/ASR/local/train_bpe_model.py | 113 ++ egs/libriheavy/ASR/prepare.sh | 314 ++++ .../ASR/zipformer/asr_datamodule.py | 443 ++++++ egs/libriheavy/ASR/zipformer/beam_search.py | 1 + egs/libriheavy/ASR/zipformer/decode.py | 794 +++++++++ egs/libriheavy/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + egs/libriheavy/ASR/zipformer/export-onnx.py | 1 + egs/libriheavy/ASR/zipformer/export.py | 1 + .../ASR/zipformer/jit_pretrained.py | 1 + egs/libriheavy/ASR/zipformer/joiner.py | 1 + egs/libriheavy/ASR/zipformer/model.py | 1 + egs/libriheavy/ASR/zipformer/onnx_decode.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/libriheavy/ASR/zipformer/optim.py | 1 + egs/libriheavy/ASR/zipformer/pretrained.py | 1 + egs/libriheavy/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_coverter.py | 1 + egs/libriheavy/ASR/zipformer/subsampling.py | 1 + .../ASR/zipformer/text_normalization.py | 50 + egs/libriheavy/ASR/zipformer/train.py | 1415 +++++++++++++++++ egs/libriheavy/ASR/zipformer/zipformer.py | 1 + requirements-ci.txt | 1 + requirements.txt | 1 + 30 files changed, 3613 insertions(+), 2 deletions(-) create mode 100644 egs/libriheavy/ASR/README.md create mode 100755 egs/libriheavy/ASR/local/compute_fbank_libriheavy.py create mode 120000 egs/libriheavy/ASR/local/compute_fbank_musan.py create mode 100755 egs/libriheavy/ASR/local/norm_text.py create mode 100755 egs/libriheavy/ASR/local/prepare_manifest.py create mode 100755 egs/libriheavy/ASR/local/train_bpe_model.py create mode 100755 egs/libriheavy/ASR/prepare.sh create mode 100644 egs/libriheavy/ASR/zipformer/asr_datamodule.py create mode 120000 egs/libriheavy/ASR/zipformer/beam_search.py create mode 100644 egs/libriheavy/ASR/zipformer/decode.py create mode 120000 egs/libriheavy/ASR/zipformer/decoder.py create mode 120000 egs/libriheavy/ASR/zipformer/encoder_interface.py create mode 120000 egs/libriheavy/ASR/zipformer/export-onnx.py create mode 120000 egs/libriheavy/ASR/zipformer/export.py create mode 120000 egs/libriheavy/ASR/zipformer/jit_pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/joiner.py create mode 120000 egs/libriheavy/ASR/zipformer/model.py create mode 120000 egs/libriheavy/ASR/zipformer/onnx_decode.py create mode 120000 egs/libriheavy/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/optim.py create mode 120000 egs/libriheavy/ASR/zipformer/pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/scaling.py create mode 120000 egs/libriheavy/ASR/zipformer/scaling_coverter.py create mode 120000 egs/libriheavy/ASR/zipformer/subsampling.py create mode 100644 egs/libriheavy/ASR/zipformer/text_normalization.py create mode 100644 egs/libriheavy/ASR/zipformer/train.py create mode 120000 egs/libriheavy/ASR/zipformer/zipformer.py diff --git a/egs/libriheavy/ASR/README.md b/egs/libriheavy/ASR/README.md new file mode 100644 index 000000000..2498d017f --- /dev/null +++ b/egs/libriheavy/ASR/README.md @@ -0,0 +1,6 @@ +# Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context + +Libriheavy is a labeled version of [Librilight](https://arxiv.org/pdf/1912.07875.pdf). Please refer to our repository [k2-fsa/libriheavy](https://github.com/k2-fsa/libriheavy) for more details. We also have a paper: *Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context*, [Preprint available on arxiv](https://arxiv.org/abs/2309.08105). 
+ + +See [RESULTS](./RESULTS.md) for the results for icefall recipes. diff --git a/egs/libriheavy/ASR/RESULTS.md b/egs/libriheavy/ASR/RESULTS.md index 4fbedad98..513bbf72e 100644 --- a/egs/libriheavy/ASR/RESULTS.md +++ b/egs/libriheavy/ASR/RESULTS.md @@ -1,6 +1,116 @@ -## Results +# Results -### Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) +## zipformer (zipformer + pruned stateless transducer) + +See for more details. + +[zipformer](./zipformer) + +### Non-streaming + +#### Training on normalized text, i.e. Upper case without punctuation + +##### normal-scaled model, number of model parameters: 65805511, i.e., 65.81 M + +You can find a pretrained model, training logs at: + + +Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set), +exp_small_subset(small set). + +Results of models: + +| training set | decoding method | librispeech clean | librispeech other | libriheavy clean | libriheavy other | comment | +|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------| +| small | greedy search | 4.19 | 9.99 | 4.75 | 10.25 |--epoch 90 --avg 20 | +| small | modified beam search| 4.05 | 9.89 | 4.68 | 10.01 |--epoch 90 --avg 20 | +| medium | greedy search | 2.39 | 4.85 | 2.90 | 6.6 |--epoch 60 --avg 20 | +| medium | modified beam search| 2.35 | 4.82 | 2.90 | 6.57 |--epoch 60 --avg 20 | +| large | greedy search | 1.67 | 3.32 | 2.24 | 5.61 |--epoch 16 --avg 3 | +| large | modified beam search| 1.62 | 3.36 | 2.20 | 5.57 |--epoch 16 --avg 3 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +python ./zipformer/train.py \ + --world-size 4 \ + --master-port 12365 \ + --exp-dir zipformer/exp \ + --num-epochs 60 \ # 16 for large; 90 for small + --lr-hours 15000 \ # 20000 for large; 5000 for small + --use-fp16 1 \ + --start-epoch 1 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --max-duration 1000 \ + --subset medium +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 16 \ + --avg 3 \ + --exp-dir zipformer/exp \ + --max-duration 1000 \ + --causal 0 \ + --decoding-method $m +done +``` + +#### Training on full formatted text, i.e. with casing and punctuation + +##### normal-scaled model, number of model parameters: 66074067 , i.e., 66M + +You can find a pretrained model, training logs at: + + +Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set), +exp_small_subset(small set). 
+ +Results of models: + +| training set | decoding method | libriheavy clean (WER) | libriheavy other (WER) | libriheavy clean (CER) | libriheavy other (CER) | comment | +|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------| +| small | modified beam search| 13.04 | 19.54 | 4.51 | 7.90 |--epoch 88 --avg 41 | +| medium | modified beam search| 9.84 | 13.39 | 3.02 | 5.10 |--epoch 50 --avg 15 | +| large | modified beam search| 7.76 | 11.32 | 2.41 | 4.22 |--epoch 16 --avg 2 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +python ./zipformer/train.py \ + --world-size 4 \ + --master-port 12365 \ + --exp-dir zipformer/exp \ + --num-epochs 60 \ # 16 for large; 90 for small + --lr-hours 15000 \ # 20000 for large; 10000 for small + --use-fp16 1 \ + --train-with-punctuation 1 \ + --start-epoch 1 \ + --bpe-model data/lang_punc_bpe_756/bpe.model \ + --max-duration 1000 \ + --subset medium +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 16 \ + --avg 3 \ + --exp-dir zipformer/exp \ + --max-duration 1000 \ + --causal 0 \ + --decoding-method $m +done +``` + +## Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) #### [zipformer_prompt_asr](./zipformer_prompt_asr) diff --git a/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py new file mode 100755 index 000000000..010531db2 --- /dev/null +++ b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the Libriheavy dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, +) + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-dir", + type=str, + help="""The source directory that contains raw manifests. 
+ """, + default="data/manifests", + ) + + parser.add_argument( + "--fbank-dir", + type=str, + help="""Fbank output dir + """, + default="data/fbank", + ) + + parser.add_argument( + "--subset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Whether to use speed perturbation.", + ) + + parser.add_argument( + "--use-splits", + type=str2bool, + default=False, + help="Whether to compute fbank on splits.", + ) + + parser.add_argument( + "--num-splits", + type=int, + help="""The number of splits of the medium and large subset. + Only needed when --use-splits is true.""", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="""Process pieces starting from this number (inclusive). + Only needed when --use-splits is true.""", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="""Stop processing pieces until this number (exclusive). + Only needed when --use-splits is true.""", + ) + + return parser.parse_args() + + +def compute_fbank_libriheavy(args): + src_dir = Path(args.manifest_dir) + output_dir = Path(args.fbank_dir) + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + subset = args.subset + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + output_cuts_path = output_dir / f"libriheavy_cuts_{subset}.jsonl.gz" + if output_cuts_path.exists(): + logging.info(f"{output_cuts_path} exists - skipping") + return + + input_cuts_path = src_dir / f"libriheavy_cuts_{subset}.jsonl.gz" + assert input_cuts_path.exists(), f"{input_cuts_path} does not exist!" 
+ logging.info(f"Loading {input_cuts_path}") + cut_set = CutSet.from_file(input_cuts_path) + + logging.info("Computing features") + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/libriheavy_feats_{subset}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + logging.info(f"Saving to {output_cuts_path}") + cut_set.to_file(output_cuts_path) + + +def compute_fbank_libriheavy_splits(args): + num_splits = args.num_splits + subset = args.subset + src_dir = f"{args.manifest_dir}/libriheavy_{subset}_split" + src_dir = Path(src_dir) + output_dir = f"{args.fbank_dir}/libriheavy_{subset}_split" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + num_digits = 8 # num_digits is fixed by lhotse split-lazy + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = src_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" + if not raw_cuts_path.is_file(): + logging.info(f"{raw_cuts_path} does not exist - skipping it") + continue + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Computing features") + if (output_dir / f"libriheavy_feats_{subset}_{idx}.lca").exists(): + logging.info(f"Removing {output_dir}/libriheavy_feats_{subset}_{idx}.lca") + os.remove(output_dir / f"libriheavy_feats_{subset}_{idx}.lca") + + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/libriheavy_feats_{subset}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + overwrite=True, + ) + + logging.info("About to split cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + logging.info(f"Saved to {cuts_path}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + if args.use_splits: + assert args.num_splits is not None, "Please provide num_splits" + compute_fbank_libriheavy_splits(args) + else: + compute_fbank_libriheavy(args) diff --git a/egs/libriheavy/ASR/local/compute_fbank_musan.py b/egs/libriheavy/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/libriheavy/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/local/norm_text.py b/egs/libriheavy/ASR/local/norm_text.py new file mode 100755 index 000000000..c2fc0d92d --- /dev/null +++ b/egs/libriheavy/ASR/local/norm_text.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. 
(authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import codecs +import sys + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", + type=str, + help="""Path to the input text. + """, + ) + return parser.parse_args() + + +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + +def main(): + args = get_args() + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")(sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer) + line = f.readline() + while line: + print(remove_punc_to_upper(line)) + line = f.readline() + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py new file mode 100755 index 000000000..42f392cae --- /dev/null +++ b/egs/libriheavy/ASR/local/prepare_manifest.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import json +import sys +from pathlib import Path + + +def simple_cleanup(text: str) -> str: + table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]") + text = text.translate(table) + return text.strip() + + +# Assign text of the supervisions and remove unnecessary entries. 
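# For example (inputs made up for illustration), simple_cleanup() above maps
# curly quotes and full-width punctuation to their ASCII forms:
#   simple_cleanup("“Hello” world。")  ->  '"Hello" world.'
# whereas remove_punc_to_upper() in ./local/norm_text.py goes further and keeps
# only letters, digits and apostrophes, upper-cased:
#   remove_punc_to_upper("it’s fine, thanks!")  ->  "IT'S FINE THANKS"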
+def main(): + assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR" + fname = Path(sys.argv[1]).name + oname = Path(sys.argv[2]) / fname + with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: + for line in fin: + cut = json.loads(line) + cut["supervisions"][0]["text"] = simple_cleanup( + cut["supervisions"][0]["custom"]["texts"][0] + ) + del cut["supervisions"][0]["custom"] + del cut["custom"] + fout.write((json.dumps(cut) + "\n").encode()) + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/local/train_bpe_model.py b/egs/libriheavy/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..19caf43ab --- /dev/null +++ b/egs/libriheavy/ASR/local/train_bpe_model.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--byte-fallback", + action="store_true", + help="""Whether to enable byte_fallback when training bpe.""", + ) + + parser.add_argument( + "--character-coverage", + type=float, + default=1.0, + help="Character coverage in vocabulary.", + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + input_sentence_size = 100000000 + + user_defined_symbols = ["<blk>", "<sos/eos>"] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it.
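# Illustration (relying on the usual icefall convention, not spelled out in
# this script): the two user-defined symbols above receive ids 0 and 1, i.e.
# <blk>=0 and <sos/eos>=1, <unk> takes id 2, and the learned BPE pieces start
# from id 3; that is why unk_id is pinned to len(user_defined_symbols) == 2.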
+ + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=args.character_coverage, + user_defined_symbols=user_defined_symbols, + byte_fallback=args.byte_fallback, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh new file mode 100755 index 000000000..af7e3c5b0 --- /dev/null +++ b/egs/libriheavy/ASR/prepare.sh @@ -0,0 +1,314 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 +export CUDA_VISIBLE_DEVICES="" + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/librilight +# You can find small, medium, large, etc. inside it. +# +# - $dl_dir/libriheavy +# You can find libriheavy_cuts_small.jsonl.gz, libriheavy_cuts_medium.jsonl.gz, etc. inside it. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +fbank_dir=data/fbank +manifests_dir=data/manifests + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download audio data." + # If you have pre-downloaded it to /path/to/librilight, + # you can create a symlink + # + # ln -sfv /path/to/librilight $dl_dir/librilight + # + mkdir -p $dl_dir/librilight + for subset in small medium large; do + log "Downloading ${subset} subset." + if [ ! -d $dl_dir/librilight/${subset} ]; then + wget -P $dl_dir/librilight -c https://dl.fbaipublicfiles.com/librilight/data/${subset}.tar + tar xf $dl_dir/librilight/${subset}.tar -C $dl_dir/librilight + else + log "Skipping download, ${subset} subset exists." + fi + done +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download manifests from huggingface." + + # If you have pre-downloaded it to /path/to/libriheavy, + # you can create a symlink + # + # ln -sfv /path/to/libriheavy $dl_dir/libriheavy + # + mkdir -p $dl_dir/libriheavy + for subset in small medium large dev test_clean test_other; do + if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz ]; then + log "Downloading ${subset} subset." + wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_${subset}.jsonl.gz + else + log "Skipping download, ${subset} subset exists." 
+ fi + done + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Download manifests from modelscope" + mkdir -p $dl_dir/libriheavy + if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_small.jsonl.gz ]; then + cd $dl_dir/libriheavy + GIT_LFS_SKIP_SMUDGE=1 git clone https://www.modelscope.cn/datasets/pkufool/Libriheavy.git + cd Libriheavy + git lfs pull --exclude "raw/*" + mv *.jsonl.gz ../ + cd .. + rm -rf Libriheavy + cd ../../ + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p $manifests_dir + if [ ! -e $manifests_dir/.musan.done ]; then + lhotse prepare musan $dl_dir/musan $manifests_dir + touch $manifests_dir/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare Libriheavy manifests" + mkdir -p $manifests_dir + for subset in small medium large dev test_clean test_other; do + if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then + log "Prepare manifest for subset : ${subset}" + ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir + fi + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p $fbank_dir + if [ ! -e $fbank_dir/.musan.done ]; then + ./local/compute_fbank_musan.py + touch $fbank_dir/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute fbank for small subset and validation subsets" + for subset in test_clean test_other dev small; do + log "Computing $subset subset." + if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then + ./local/compute_fbank_libriheavy.py \ + --manifest-dir ${manifests_dir} \ + --subset ${subset} \ + --fbank-dir $fbank_dir \ + --num-workers $nj + fi + done +fi + +num_per_split=8000 +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Split medium and large subsets." + for subset in medium large; do + log "Spliting subset : $subset" + split_dir=$manifests_dir/libriheavy_${subset}_split + mkdir -p $split_dir + if [ ! -e $split_dir/.split_completed ]; then + lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split + touch $split_dir/.split_completed + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compute fbank for medium and large subsets" + mkdir -p $fbank_dir + chunk_size=20 + for subset in medium large; do + if [ $subset == "large" ]; then + chunk_size=200 + fi + num_splits=$(find $manifests_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz" | wc -l) + if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then + for i in $(seq 0 1 6); do + start=$(( i * $chunk_size )) + end=$(( (i+1) * $chunk_size )) + ./local/compute_fbank_libriheavy.py \ + --manifest-dir ${manifests_dir} \ + --use-splits 1 \ + --subset ${subset} \ + --fbank-dir $fbank_dir \ + --num-splits $num_splits \ + --num-workers $nj \ + --start $start \ + --stop $end & + done + wait + touch $fbank_dir/.libriheavy.${subset}.done + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Combine features for medium and large subsets." + for subset in medium large; do + log "Combining $subset subset." + if [ ! 
-f $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then + pieces=$(find $fbank_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz") + lhotse combine $pieces $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz + fi + done +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Train BPE model for normalized text" + + if [ ! -f data/texts ]; then + gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \ + | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \ + | ./local/norm_text.py > data/texts + fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + cp data/texts $lang_dir/text + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + done +fi + + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Train BPE model for unnormalized text" + if [ ! -f data/punc_texts ]; then + gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \ + | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts + fi + for vocab_size in ${vocab_sizes[@]}; do + new_vocab_size=$(($vocab_size + 256)) + lang_dir=data/lang_punc_bpe_${new_vocab_size} + mkdir -p $lang_dir + + cp data/punc_texts $lang_dir/text + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --byte-fallback \ + --vocab-size ${new_vocab_size} \ + --character-coverage 0.99 \ + --transcript $lang_dir/text + fi + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Prepare language model for normalized text" + + for subset in small medium large; do + if [ ! -f $manifests_dir/texts_${subset} ]; then + gunzip -c $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \ + | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \ + | ./local/norm_text.py > $manifests_dir/texts_${subset} + fi + done + + mkdir -p data/lm + if [ ! -f data/lm/text ]; then + cat $manifests_dir/texts_small $manifests_dir/texts_medium $manifests_dir/texts_large > data/lm/text + fi + + (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \ + > data/lm/words.txt + + cat data/lm/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \ + | awk '{print $1" "NR+3}' >> data/lm/words.txt + + num_lines=$(< data/lm/words.txt wc -l) + (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \ + >> data/lm/words.txt + + # Train LM on transcripts + if [ ! -f data/lm/3-gram.unpruned.arpa ]; then + python3 ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text data/lm/text \ + -lm data/lm/3-gram.unpruned.arpa + fi + + # We assume you have installed kaldilm; if not, please install + # it using: pip install kaldilm + if [ !
-f data/lm/G_3_gram_char.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table=data/lm/words.txt \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt + fi +fi + diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..df761c1b8 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,443 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriHeavyAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--subset", + type=str, + default="S", + help="""The subset to be used. Should be S, M or L. Note: S subset + includes libriheavy_cuts_small.jsonl.gz, M subset includes + libriheavy_cuts_small.jsonl.gz and libriheavy_cuts_medium.jsonl.gz, + L subset includes libriheavy_cuts_small.jsonl.gz, + libriheavy_cuts_medium.jsonl.gz and libriheavy_cuts_large.jsonl.gz. 
+ """, + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_small_cuts(self) -> CutSet: + logging.info("About to get small subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz" + ) + + @lru_cache() + def train_medium_cuts(self) -> CutSet: + logging.info("About to get medium subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz" + ) + + @lru_cache() + def train_large_cuts(self) -> CutSet: + logging.info("About to get large subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get the test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get the test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz" + ) diff --git a/egs/libriheavy/ASR/zipformer/beam_search.py b/egs/libriheavy/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py new file mode 100644 index 000000000..1928e2635 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/decode.py @@ 
-0,0 +1,794 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from text_normalization import remove_punc_to_upper +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--train-with-punctuation", + type=str2bool, + default=False, + help="""Set to True, if the model was trained on texts with casing + and punctuation.""", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=False, + help="""Upper case and remove all chars except ' and - + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
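# Illustrative sketch, not part of the recipe: the padding below appends pad_len
# frames filled with LOG_EPS to the time axis of the (N, T, C) feature tensor.
# F.pad reads its pad tuple from the last dimension backwards, i.e.
# (C_left, C_right, T_left, T_right); the shapes here are made up.
import math
import torch

x = torch.zeros(2, 100, 80)  # (N=2, T=100 frames, C=80 mel bins)
y = torch.nn.functional.pad(x, pad=(0, 0, 0, 30), value=math.log(1e-10))
assert y.shape == (2, 130, 80)  # 30 extra frames appended at the end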
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. 
+ decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + this_batch = [] + if params.post_normalization and params.train_with_punctuation: + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = remove_punc_to_upper(ref_text).split() + hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[f"{name}_norm"].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
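# Illustrative sketch, not part of the recipe: each entry in `results` (as built
# by decode_dataset above) is a (cut_id, ref_words, hyp_words) triple; the cut id
# and words below are made up. store_transcripts() writes these triples as
# ref/hyp pairs and write_error_stats() computes the WER from them.
example_entry = (
    "medium/10001-0042",   # hypothetical cut id
    ["HELLO", "WORLD"],    # reference words
    ["HELLO", "WORLDS"],   # hypothesis words
)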
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") 
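# Illustrative sketch, not part of the recipe: without --use-averaged-model,
# decoding with --epoch 30 --avg 15 averages the last 15 epoch checkpoints,
# i.e. epoch-16.pt ... epoch-30.pt, as the epoch branch just below shows.
epoch, avg = 30, 15
to_average = [f"zipformer/exp/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
assert len(to_average) == 15 and to_average[0].endswith("epoch-16.pt")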
+ model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + def normalize_text(c: Cut): + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c + + test_clean_cuts = libriheavy.test_clean_cuts() + test_other_cuts = libriheavy.test_other_cuts() + + if not params.train_with_punctuation: + test_clean_cuts = test_clean_cuts.map(normalize_text) + test_other_cuts = test_other_cuts.map(normalize_text) + + test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts) + test_other_dl = libriheavy.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer/decoder.py b/egs/libriheavy/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/encoder_interface.py b/egs/libriheavy/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/export-onnx.py b/egs/libriheavy/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/export.py b/egs/libriheavy/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/jit_pretrained.py b/egs/libriheavy/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/joiner.py b/egs/libriheavy/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/model.py b/egs/libriheavy/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/onnx_decode.py b/egs/libriheavy/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/onnx_pretrained.py b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py new file mode 
120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/optim.py b/egs/libriheavy/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/pretrained.py b/egs/libriheavy/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/scaling.py b/egs/libriheavy/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/scaling_coverter.py b/egs/libriheavy/ASR/zipformer/scaling_coverter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/scaling_coverter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/subsampling.py b/egs/libriheavy/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py new file mode 100644 index 000000000..92590769c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/text_normalization.py @@ -0,0 +1,50 @@ +from num2words import num2words + + +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + +def word_normalization(word: str) -> str: + # 1. Use full word for some abbreviation + # 2. Convert digits to english words + # 3. 
Convert ordinal number to english words + if word == "MRS": + return "MISSUS" + if word == "MR": + return "MISTER" + if word == "ST": + return "SAINT" + if word == "ECT": + return "ET CETERA" + + if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric(): # e.g 9TH, 6TH + word = num2words(word[:-2], to="ordinal") + word = word.replace("-", " ") + + if word.isnumeric(): + num = int(word) + if num > 1500 and num < 2030: + word = num2words(word, to="year") + else: + word = num2words(word) + word = word.replace("-", " ") + return word.upper() + + +def text_normalization(text: str) -> str: + text = text.upper() + return " ".join([word_normalization(x) for x in text.split()]) + + +if __name__ == "__main__": + assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK" + assert ( + text_normalization("Hello Mrs st 21st world 3rd she 99th MR") + == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER" + ) diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py new file mode 100644 index 000000000..c97da4a11 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -0,0 +1,1415 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
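# Illustrative sketch, not part of the recipe: a few more word_normalization()
# examples from text_normalization.py above. The expected strings assume
# num2words' usual English spellings.
from text_normalization import word_normalization

assert word_normalization("42") == "FORTY TWO"            # "forty-two", hyphen removed
assert word_normalization("1815") == "EIGHTEEN FIFTEEN"   # 1500 < n < 2030 is read as a year
assert word_normalization("21ST") == "TWENTY FIRST"       # ordinal suffix handling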
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from text_normalization import remove_punc_to_upper +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-hours", + type=float, + default=30000, + help="""Number of hours that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--train-with-punctuation", + type=str2bool, + default=False, + help="If True, the training text will include casing and punctuation.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. 
+ + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + 
joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. 
+ + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
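+    # The sums stored here are normalized later, e.g. as
+    # `tot_loss["loss"] / tot_loss["frames"]` below, so the reported numbers
+    # are average losses per (subsampled) frame.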
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
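+            # Standard AMP flow: scale the loss before backward so small fp16
+            # gradients do not underflow; scaler.step() unscales the gradients
+            # and skips the optimizer step if they contain inf/nan values;
+            # scaler.update() then adjusts the scale factor.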
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + # Use the number of hours of speech to adjust the learning rate + scheduler.step_epoch( + params.batch_idx_train * params.max_duration * params.world_size / 3600 + ) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and 
`world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_hours) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def normalize_text(c: Cut): + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 2.0 or c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + libriheavy = LibriHeavyAsrDataModule(args) + + train_cuts = libriheavy.train_small_cuts() + if params.subset == "M" or params.subset == "L": + train_cuts += libriheavy.train_medium_cuts() + if params.subset == "L": + train_cuts += libriheavy.train_large_cuts() + + if not params.train_with_punctuation: + train_cuts = train_cuts.map(normalize_text) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = libriheavy.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = libriheavy.dev_cuts() + + if not params.train_with_punctuation: + valid_cuts = valid_cuts.map(normalize_text) + + valid_dl = libriheavy.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer/zipformer.py b/egs/libriheavy/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/requirements-ci.txt b/requirements-ci.txt index e1232a768..6c74f688c 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -17,6 +17,7 @@ six git+https://github.com/lhotse-speech/lhotse kaldilm==1.11 kaldialign==0.7.1 +num2words sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 diff --git a/requirements.txt b/requirements.txt index 5a8326619..9502fcbd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ kaldifst kaldilm kaldialign +num2words kaldi-decoder sentencepiece>=0.1.96 tensorboard From ae67f75e9c429d35e8a84d6d70cc8050eae37c86 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 26 Nov 2023 10:04:15 +0800 Subject: [PATCH 051/216] a bilingual recipe similar to the `multi-zh_hans` (#1265) --- ...rmer.sh => run-multi-corpora-zipformer.sh} | 38 + ...er.yml => run-multi-corpora-zipformer.yml} | 10 +- egs/multi_zh_en/ASR/README.md | 19 + egs/multi_zh_en/ASR/RESULTS.md | 44 + egs/multi_zh_en/ASR/local/compile_lg.py | 1 + egs/multi_zh_en/ASR/local/prepare_char.py | 1 + .../ASR/local/prepare_for_bpe_model.py | 65 + egs/multi_zh_en/ASR/local/prepare_lang.py | 1 + .../ASR/local/prepare_lang_bbpe.py | 1 + 
egs/multi_zh_en/ASR/local/prepare_lang_bpe.py | 1 + egs/multi_zh_en/ASR/local/prepare_words.py | 1 + egs/multi_zh_en/ASR/local/text2segments.py | 1 + egs/multi_zh_en/ASR/local/text2token.py | 1 + egs/multi_zh_en/ASR/local/train_bbpe_model.py | 1 + .../ASR/local/validate_bpe_lexicon.py | 1 + egs/multi_zh_en/ASR/prepare.sh | 149 ++ egs/multi_zh_en/ASR/shared | 1 + .../ASR/zipformer/asr_datamodule.py | 385 +++++ egs/multi_zh_en/ASR/zipformer/beam_search.py | 1 + egs/multi_zh_en/ASR/zipformer/decode.py | 851 ++++++++++ egs/multi_zh_en/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/multi_zh_en/ASR/zipformer/export-onnx.py | 1 + egs/multi_zh_en/ASR/zipformer/export.py | 541 +++++++ .../ASR/zipformer/generate_averaged_model.py | 193 +++ .../ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_ctc.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/multi_zh_en/ASR/zipformer/joiner.py | 1 + egs/multi_zh_en/ASR/zipformer/model.py | 1 + .../ASR/zipformer/multi_dataset.py | 247 +++ egs/multi_zh_en/ASR/zipformer/onnx_check.py | 1 + egs/multi_zh_en/ASR/zipformer/onnx_decode.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/multi_zh_en/ASR/zipformer/optim.py | 1 + egs/multi_zh_en/ASR/zipformer/pretrained.py | 378 +++++ egs/multi_zh_en/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 1 + egs/multi_zh_en/ASR/zipformer/subsampling.py | 1 + egs/multi_zh_en/ASR/zipformer/train.py | 1416 +++++++++++++++++ egs/multi_zh_en/ASR/zipformer/zipformer.py | 1 + 45 files changed, 4363 insertions(+), 5 deletions(-) rename .github/scripts/{run-multi-zh_hans-zipformer.sh => run-multi-corpora-zipformer.sh} (66%) rename .github/workflows/{run-multi-zh_hans-zipformer.yml => run-multi-corpora-zipformer.yml} (91%) create mode 100644 egs/multi_zh_en/ASR/README.md create mode 100644 egs/multi_zh_en/ASR/RESULTS.md create mode 120000 egs/multi_zh_en/ASR/local/compile_lg.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_char.py create mode 100755 egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang_bpe.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_words.py create mode 120000 egs/multi_zh_en/ASR/local/text2segments.py create mode 120000 egs/multi_zh_en/ASR/local/text2token.py create mode 120000 egs/multi_zh_en/ASR/local/train_bbpe_model.py create mode 120000 egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/multi_zh_en/ASR/prepare.sh create mode 120000 egs/multi_zh_en/ASR/shared create mode 100644 egs/multi_zh_en/ASR/zipformer/asr_datamodule.py create mode 120000 egs/multi_zh_en/ASR/zipformer/beam_search.py create mode 100755 egs/multi_zh_en/ASR/zipformer/decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/decoder.py create mode 120000 egs/multi_zh_en/ASR/zipformer/encoder_interface.py create mode 120000 egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/export-onnx.py create mode 100755 egs/multi_zh_en/ASR/zipformer/export.py create mode 100755 egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained.py create 
mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/joiner.py create mode 120000 egs/multi_zh_en/ASR/zipformer/model.py create mode 100644 egs/multi_zh_en/ASR/zipformer/multi_dataset.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_check.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/multi_zh_en/ASR/zipformer/optim.py create mode 100755 egs/multi_zh_en/ASR/zipformer/pretrained.py create mode 120000 egs/multi_zh_en/ASR/zipformer/scaling.py create mode 120000 egs/multi_zh_en/ASR/zipformer/scaling_converter.py create mode 120000 egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py create mode 120000 egs/multi_zh_en/ASR/zipformer/streaming_decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/subsampling.py create mode 100755 egs/multi_zh_en/ASR/zipformer/train.py create mode 120000 egs/multi_zh_en/ASR/zipformer/zipformer.py diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-corpora-zipformer.sh similarity index 66% rename from .github/scripts/run-multi-zh_hans-zipformer.sh rename to .github/scripts/run-multi-corpora-zipformer.sh index cbd86a4d3..90f859f43 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-corpora-zipformer.sh @@ -95,3 +95,41 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav done + +rm -rf $repo + +cd ../../../egs/multi_zh_en/ASR +log "==== Test icefall-asr-zipformer-multi-zh-en-2023-11-22 ====" +repo_url=https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/ + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +./zipformer/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \ + --method greedy_search \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav + +for method in modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav +done + +rm -rf $repo diff --git a/.github/workflows/run-multi-zh_hans-zipformer.yml b/.github/workflows/run-multi-corpora-zipformer.yml similarity index 91% rename from .github/workflows/run-multi-zh_hans-zipformer.yml rename to .github/workflows/run-multi-corpora-zipformer.yml index 72c0775a7..38f7eb908 100644 --- a/.github/workflows/run-multi-zh_hans-zipformer.yml +++ b/.github/workflows/run-multi-corpora-zipformer.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-name: run-multi-zh_hans-zipformer +name: run-multi-corpora-zipformer on: push: @@ -24,12 +24,12 @@ on: types: [labeled] concurrency: - group: run_multi-zh_hans_zipformer-${{ github.ref }} + group: run_multi-corpora_zipformer-${{ github.ref }} cancel-in-progress: true jobs: - run_multi-zh_hans_zipformer: - if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' + run_multi-corpora_zipformer: + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' || github.event.label.name == 'multi-corpora' runs-on: ${{ matrix.os }} strategy: matrix: @@ -81,4 +81,4 @@ jobs: export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-multi-zh_hans-zipformer.sh + .github/scripts/run-multi-corpora-zipformer.sh diff --git a/egs/multi_zh_en/ASR/README.md b/egs/multi_zh_en/ASR/README.md new file mode 100644 index 000000000..29341571d --- /dev/null +++ b/egs/multi_zh_en/ASR/README.md @@ -0,0 +1,19 @@ +# Introduction + +This recipe includes scripts for training Zipformer model using both English and Chinese datasets. + +# Included Training Sets + +1. LibriSpeech (English) +2. AiShell-2 (Chinese) +3. TAL-CSASR (Code-Switching, Chinese and English) + +|Datset| Number of hours| URL| +|---|---:|---| +|**TOTAL**|2,547|---| +|LibriSpeech|960|https://www.openslr.org/12/| +|AiShell-2|1,000|http://www.aishelltech.com/aishell_2| +|TAL-CSASR|587|https://ai.100tal.com/openData/voice| + + + diff --git a/egs/multi_zh_en/ASR/RESULTS.md b/egs/multi_zh_en/ASR/RESULTS.md new file mode 100644 index 000000000..3562d6ac3 --- /dev/null +++ b/egs/multi_zh_en/ASR/RESULTS.md @@ -0,0 +1,44 @@ +## Results + +### Zh-En datasets bpe-based training results (Non-streaming) on Zipformer model + +This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1265) in icefall. + +#### Non-streaming (Byte-Level BPE vocab_size=2000) + +Best results (num of params : ~69M): + +The training command: + +``` +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 35 \ + --use-fp16 1 \ + --max-duration 1000 \ + --num-workers 8 +``` + +The decoding command: + +``` +for method in greedy_search modified_beam_search fast_beam_search; do + ./zipformer/decode.py \ + --epoch 34 \ + --avg 19 \ + --decoding-method $method +done +``` + +Word Error Rates (WERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model (# tokens is 2000). + +| Datasets | TAL-CSASR | TAL-CSASR | AiShell-2 | AiShell-2 | LibriSpeech | LibriSpeech | +|----------------------|-----------|-----------|-----------|-----------|-------------|-------------| +| Zipformer WER (%) | dev | test | dev | test | test-clean | test-other | +| greedy_search | 6.65 | 6.69 | 6.57 | 7.03 | 2.43 | 5.70 | +| modified_beam_search | 6.46 | 6.51 | 6.18 | 6.60 | 2.41 | 5.57 | +| fast_beam_search | 6.57 | 6.68 | 6.40 | 6.74 | 2.40 | 5.56 | + +Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22, which is trained on LibriSpeech 960-hour training set (with speed perturbation), TAL-CSASR training set (with speed perturbation) and AiShell-2 (w/o speed perturbation). 
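+To give a rough idea of what `--epoch 34 --avg 19` means: decoding uses a model
+whose weights are an average over the last 19 epoch checkpoints rather than a
+single checkpoint (icefall also offers a running-average variant via
+`--use-averaged-model`). The snippet below is only an illustrative sketch of
+plain checkpoint averaging, assuming the usual icefall checkpoint layout with a
+`model` key; it is not the exact code used by `./zipformer/decode.py`.
+
+```
+import torch
+
+# Average the parameters of epoch-16.pt ... epoch-34.pt (19 checkpoints).
+ckpts = [f"zipformer/exp/epoch-{i}.pt" for i in range(16, 35)]
+avg = None
+for path in ckpts:
+    state = torch.load(path, map_location="cpu")["model"]
+    if avg is None:
+        avg = {k: v.clone().float() for k, v in state.items()}
+    else:
+        for k in avg:
+            avg[k] += state[k].float()
+for k in avg:
+    avg[k] /= len(ckpts)  # averaged state dict, ready to load into a model
+```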
+ + diff --git a/egs/multi_zh_en/ASR/local/compile_lg.py b/egs/multi_zh_en/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/multi_zh_en/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_char.py b/egs/multi_zh_en/ASR/local/prepare_char.py new file mode 120000 index 000000000..42743b544 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_char.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py new file mode 100755 index 000000000..00514e6bb --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script tokenizes the training transcript by CJK characters +# and saves the result to transcript_chars.txt, which is used +# to train the BPE model later. + +import argparse +from pathlib import Path + +from tqdm.auto import tqdm + +from icefall.utils import tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Output directory. + The generated transcript_chars.txt is saved to this directory. + """, + ) + + parser.add_argument( + "--text", + type=str, + help="Training transcript.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + text = Path(args.text) + + assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!" 
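+    # Note: tokenize_by_CJK_char (imported from icefall.utils) inserts spaces
+    # around each CJK character while leaving other tokens (e.g. English
+    # words) untouched, which yields the per-character segmentation used to
+    # train the BPE model.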
+ + transcript_path = lang_dir / "transcript_chars.txt" + + with open(text, "r", encoding="utf-8") as fin: + with open(transcript_path, "w+", encoding="utf-8") as fout: + for line in tqdm(fin): + fout.write(tokenize_by_CJK_char(line) + "\n") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/local/prepare_lang.py b/egs/multi_zh_en/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py new file mode 120000 index 000000000..9a0b44642 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_lang_bbpe.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_words.py b/egs/multi_zh_en/ASR/local/prepare_words.py new file mode 120000 index 000000000..ef2b4eaf3 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_words.py @@ -0,0 +1 @@ +../../../aishell2/ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/text2segments.py b/egs/multi_zh_en/ASR/local/text2segments.py new file mode 120000 index 000000000..7d68a39c3 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/text2segments.py @@ -0,0 +1 @@ +../../../wenetspeech/ASR/local/text2segments.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/text2token.py b/egs/multi_zh_en/ASR/local/text2token.py new file mode 120000 index 000000000..ce5cfd537 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/text2token.py @@ -0,0 +1 @@ +../../../wenetspeech/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/train_bbpe_model.py b/egs/multi_zh_en/ASR/local/train_bbpe_model.py new file mode 120000 index 000000000..7fb4a9f9d --- /dev/null +++ b/egs/multi_zh_en/ASR/local/train_bbpe_model.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/train_bbpe_model.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh new file mode 100755 index 000000000..9f2be5a5c --- /dev/null +++ b/egs/multi_zh_en/ASR/prepare.sh @@ -0,0 +1,149 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +vocab_sizes=( + 2000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. 
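+# Example: run a single stage, e.g. only the byte-level BPE lang preparation:
+#   ./prepare.sh --stage 4 --stop-stage 4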
+mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +log "Dataset: musan" +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Soft link fbank of musan" + mkdir -p data/fbank + if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) . + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) . + cd ../.. + else + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4" + exit 1 + fi +fi + +log "Dataset: LibriSpeech" +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Soft link fbank of LibriSpeech" + mkdir -p data/fbank + if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) . + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) . + cd ../.. + else + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi +fi + +log "Dataset: AiShell-2" +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Soft link fbank of AiShell-2" + mkdir -p data/fbank + if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts*) . + ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats*) . + cd ../.. + else + log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare Byte BPE based lang" + mkdir -p data/fbank + if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then + log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi + + if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6" + exit 1 + fi + + cd data/ + if [ ! -d ./lang_char ]; then + ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) . + fi + if [ ! -d ./lang_bpe_500 ]; then + ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) . + fi + cd ../ + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + mkdir -p $lang_dir + + cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \ + > $lang_dir/text + + if [ ! -f $lang_dir/transcript_chars.txt ]; then + ./local/prepare_for_bpe_model.py \ + --lang-dir ./$lang_dir \ + --text $lang_dir/text + fi + + if [ ! -f $lang_dir/text_words_segmentation ]; then + python3 ./local/text2segments.py \ + --input-file ./data/lang_char/text \ + --output-file $lang_dir/text_words_segmentation + + cat ./data/lang_bpe_500/transcript_words.txt \ + >> $lang_dir/text_words_segmentation + + cat ./data/lang_char/text \ + >> $lang_dir/text + fi + + cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt + + if [ ! -f $lang_dir/words.txt ]; then + python3 ./local/prepare_words.py \ + --input-file $lang_dir/words_no_ids.txt \ + --output-file $lang_dir/words.txt + fi + + if [ ! 
-f $lang_dir/bbpe.model ]; then + ./local/train_bbpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bbpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bbpe.model + fi + done +fi + diff --git a/egs/multi_zh_en/ASR/shared b/egs/multi_zh_en/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/multi_zh_en/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..be6e94472 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,385 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=300.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl diff --git a/egs/multi_zh_en/ASR/zipformer/beam_search.py b/egs/multi_zh_en/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py new file mode 100755 index 000000000..e21e8f052 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/decode.py @@ -0,0 +1,851 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_2000/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-tal-csasr", + type=str2bool, + default=False, + help="Whether to use TAL-CSASR training data.", + ) + + parser.add_argument( + "--use-librispeech", + type=str2bool, + default=False, + help="Whether to use LibriSpeech training data.", + ) + + parser.add_argument( + "--use-aishell2", + type=str2bool, + default=False, + help="Whether to use Aishell-2 training data.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
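Because of the effect noted in the comment above, the code below right-pads the features with 30 frames of near-silence before running a causal (streaming) encoder. The padding in isolation, with shapes spelled out (the helper name is ours, only for illustration):

```python
import math

import torch

LOG_EPS = math.log(1e-10)


def pad_for_causal(feature: torch.Tensor, feature_lens: torch.Tensor, pad_len: int = 30):
    # feature: (N, T, C), feature_lens: (N,).  Append pad_len frames of a very
    # small log-energy on the time axis and extend the lengths accordingly.
    padded = torch.nn.functional.pad(feature, pad=(0, 0, 0, pad_len), value=LOG_EPS)
    return padded, feature_lens + pad_len


feats = torch.randn(2, 100, 80)
lens = torch.tensor([100, 80])
padded, new_lens = pad_for_causal(feats, lens)
assert padded.shape == (2, 130, 80) and new_lens.tolist() == [130, 110]
```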
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode( + byte_encode(tokenize_by_CJK_char(supervisions["text"])) + ), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += 
f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [tokenize_by_CJK_char(str(text)).split() for text in texts] + # print(texts) + # exit() + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints 
({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
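A few lines below, `remove_short_utt` drops cuts whose features would leave the encoder with zero output frames after the roughly 4x subsampling of the zipformer frontend. The arithmetic in isolation (mirroring the expression used there):

```python
def encoder_frames(num_frames: int) -> int:
    # Mirrors T = ((c.num_frames - 7) // 2 + 1) // 2 from remove_short_utt.
    return ((num_frames - 7) // 2 + 1) // 2


for n in (5, 9, 100, 3000):
    print(n, "->", encoder_frames(n))
# 5 -> 0 (would be excluded), 9 -> 1, 100 -> 23, 3000 -> 748
```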
+ args.return_cuts = True + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/decoder.py b/egs/multi_zh_en/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/encoder_interface.py b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx.py b/egs/multi_zh_en/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export.py b/egs/multi_zh_en/ASR/zipformer/export.py new file mode 100755 index 000000000..fbd9ce0dd --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. 
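Before the detailed command lines that follow, here is the TorchScript round trip that `--jit 1` performs, reduced to a toy module; the real script applies this to the full model after wrapping its encoder (see further below):

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


scripted = torch.jit.script(Toy())
scripted.save("jit_script.pt")  # analogous to exp_dir/jit_script.pt
restored = torch.jit.load("jit_script.pt")  # no Python class definition needed
assert torch.equal(restored(torch.tensor([-1.0, 2.0])), torch.tensor([0.0, 2.0]))
```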
+ +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. 
You can get it at + +- non-streaming model: +https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +import re +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, str2bool + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bbpe_2000/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. 
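Both wrappers call `make_pad_mask(x_lens)` to build `src_key_padding_mask`, marking which frames are padding. A tiny re-implementation of that convention (True means a padded position), assuming icefall's helper behaves the same way:

```python
import torch


def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # lengths: (N,) true sequence lengths; returns an (N, max_len) bool mask.
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)


print(make_pad_mask_sketch(torch.tensor([3, 1])))
# tensor([[False, False, False],
#         [False,  True,  True]])
```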
For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, 
inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py new file mode 100755 index 000000000..68111fad7 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the checkpoint exp_dir/epoch-xxx.pt +./zipformer/generate_averaged_model.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(2) use the checkpoint exp_dir/checkpoint-iter.pt +./zipformer/generate_averaged_model.py \ + --iter 22000 \ + --avg 5 \ + --exp-dir ./zipformer/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path + +import k2 +import torch +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + symbol_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = symbol_table[""] + params.unk_id = symbol_table[""] + params.vocab_size = len(symbol_table) + + print("About to create model") + model = get_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ 
b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 000000000..9a8da5844 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/joiner.py b/egs/multi_zh_en/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/model.py b/egs/multi_zh_en/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py new file mode 100644 index 000000000..1155a3dcc --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py @@ -0,0 +1,247 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
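The `MultiDataset` class defined below mixes several corpora with `CutSet.mux`, weighting each source by its size so that batches sample the corpora roughly in proportion to their amount of data. A reduced sketch of that pattern (the manifest paths assume the prepared `data/fbank` directory of this recipe):

```python
from lhotse import CutSet, load_manifest_lazy

corpus_a = load_manifest_lazy("data/fbank/aishell2_cuts_train.jsonl.gz")
corpus_b = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")

# Interleave the two lazily-loaded CutSets; weights set the sampling probabilities.
mixed = CutSet.mux(
    corpus_a,
    corpus_b,
    weights=[len(corpus_a), len(corpus_b)],
)
```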
+ + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Dict + +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, args: argparse.Namespace): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - aishell2_cuts_train.jsonl.gz + """ + self.fbank_dir = Path(args.manifest_dir) + self.use_tal_csasr = args.use_tal_csasr + self.use_librispeech = args.use_librispeech + self.use_aishell2 = args.use_aishell2 + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + # AISHELL-2 + if self.use_aishell2: + logging.info("Loading Aishell-2 in lazy mode") + aishell_2_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_train.jsonl.gz" + ) + + # TAL-CSASR + if self.use_tal_csasr: + logging.info("Loading TAL-CSASR in lazy mode") + tal_csasr_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz" + ) + + # LibriSpeech + if self.use_librispeech: + logging.info("Loading LibriSpeech in lazy mode") + train_clean_100_cuts = self.train_clean_100_cuts() + train_clean_360_cuts = self.train_clean_360_cuts() + train_other_500_cuts = self.train_other_500_cuts() + + if self.use_tal_csasr and self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + tal_csasr_cuts, + weights=[ + len(aishell_2_cuts), + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + len(tal_csasr_cuts), + ], + ) + elif not self.use_tal_csasr and self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + weights=[ + len(aishell_2_cuts), + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + ], + ) + elif self.use_tal_csasr and not self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + tal_csasr_cuts, + weights=[ + len(aishell_2_cuts), + len(tal_csasr_cuts), + ], + ) + elif self.use_tal_csasr and self.use_librispeech and not self.use_aishell2: + return CutSet.mux( + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + tal_csasr_cuts, + weights=[ + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + len(tal_csasr_cuts), + ], + ) + else: + raise NotImplementedError( + f"""Not implemented for + use_aishell2: {self.use_aishell2} + use_librispeech: {self.use_librispeech} + use_tal_csasr: {self.use_tal_csasr}""" + ) + + def dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + + # AISHELL-2 + logging.info("Loading Aishell-2 DEV set in lazy mode") + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # LibriSpeech + dev_clean_cuts = self.dev_clean_cuts() + dev_other_cuts = self.dev_other_cuts() + + logging.info("Loading TAL-CSASR set in lazy mode") + tal_csasr_dev_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz" + ) + + return CutSet.mux( + aishell2_dev_cuts, + dev_clean_cuts, + dev_other_cuts, + tal_csasr_dev_cuts, + weights=[ + len(aishell2_dev_cuts), + len(dev_clean_cuts), + len(dev_other_cuts), + len(tal_csasr_dev_cuts), + ], + ) + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + # AISHELL-2 + if self.use_aishell2: + logging.info("Loading Aishell-2 set in lazy 
mode") + aishell2_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_test.jsonl.gz" + ) + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # LibriSpeech + if self.use_librispeech: + test_clean_cuts = self.test_clean_cuts() + test_other_cuts = self.test_other_cuts() + + logging.info("Loading TAL-CSASR set in lazy mode") + tal_csasr_test_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz" + ) + tal_csasr_dev_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz" + ) + + test_cuts = { + "tal_csasr_test": tal_csasr_test_cuts, + "tal_csasr_dev": tal_csasr_dev_cuts, + } + + if self.use_aishell2: + test_cuts.update( + { + "aishell-2_test": aishell2_test_cuts, + "aishell-2_dev": aishell2_dev_cuts, + } + ) + if self.use_librispeech: + test_cuts.update( + { + "librispeech_test_clean": test_clean_cuts, + "librispeech_test_other": test_other_cuts, + } + ) + return test_cuts + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_check.py b/egs/multi_zh_en/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_decode.py b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ 
b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/optim.py b/egs/multi_zh_en/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py new file mode 100755 index 000000000..676272e1f --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 23 \ + --avg 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 23 \ + --avg 1 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ 
+ /path/to/bar.wav + + +You can also use `./zipformer/exp/epoch-xx.pt`. + +Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to byte-level bpe model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> and <unk> are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.unk_id = sp.piece_to_id("<unk>") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + 
main() diff --git a/egs/multi_zh_en/ASR/zipformer/scaling.py b/egs/multi_zh_en/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/scaling_converter.py b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py new file mode 120000 index 000000000..13fd02a78 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/subsampling.py b/egs/multi_zh_en/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py new file mode 100755 index 000000000..310c8fe59 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -0,0 +1,1416 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
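A note on the text pipeline shared by the two scripts around this point: `pretrained.py` above post-processes decoded pieces with `smart_byte_decode`, while the `train.py` that follows pre-processes supervision text with `byte_encode(tokenize_by_CJK_char(...))` before BPE lookup (see `tokenize_and_encode_text` further down). The sketch below is not part of the patch; it only illustrates that round trip, assuming a byte-level BPE model has been trained at the recipe's default path, and the sample sentence is purely illustrative.

```python
# Illustrative round trip of the byte-level BPE text pipeline used in this recipe.
# Assumes data/lang_bbpe_2000/bbpe.model (the recipe default) has been trained.
import sentencepiece as spm

from icefall import byte_encode, smart_byte_decode
from icefall.utils import tokenize_by_CJK_char

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_2000/bbpe.model")

text = "HELLO 你好"

# Separate CJK characters, then map the text into the byte-level form that
# the bbpe model was trained on (same as tokenize_and_encode_text in train.py).
encoded = byte_encode(tokenize_by_CJK_char(text))
token_ids = sp.encode(encoded, out_type=int)

# After decoding, smart_byte_decode maps the byte-level pieces back to text
# (same as the hypothesis post-processing in pretrained.py).
recovered = smart_byte_decode(sp.decode(token_ids))
print(recovered)
```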
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from multi_dataset import MultiDataset +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + tokenize_by_CJK_char, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_2000/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-tal-csasr", + type=str2bool, + default=False, + help="Whether to use TAL-CSASR training data.", + ) + + parser.add_argument( + "--use-librispeech", + type=str2bool, + default=False, + help="Whether to use LibriSpeech training data.", + ) + + parser.add_argument( + "--use-aishell2", + type=str2bool, + default=False, + help="Whether to use Aishell-2 training data.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + 
encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. 
world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args) + + train_cuts = multi_dataset.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_cuts = train_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = data_module.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = multi_dataset.dev_cuts() + valid_dl = data_module.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/zipformer.py b/egs/multi_zh_en/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 0622dea30deacf2680dcca0549f7a05c0b965066 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Wed, 29 Nov 2023 21:28:38 +0800 Subject: [PATCH 052/216] Add a TTS recipe VITS on LJSpeech dataset (#1372) * first commit * replace phonimizer with g2p * use Conformer as text encoder * modify training script, clean codes * rename directory * convert text to tokens in data preparation stage * fix tts_datamodule.py * support onnx export and testing the exported onnx model * add doc * add README.md * fix style --- .flake8 | 2 +- docs/source/recipes/TTS/index.rst | 7 + docs/source/recipes/TTS/ljspeech/vits.rst | 113 +++ docs/source/recipes/index.rst | 3 +- .../TTS/local/compute_spectrogram_ljspeech.py | 106 ++ .../TTS/local/display_manifest_statistics.py | 73 ++ egs/ljspeech/TTS/local/prepare_token_file.py | 104 ++ .../TTS/local/prepare_tokens_ljspeech.py | 59 ++ egs/ljspeech/TTS/local/validate_manifest.py | 70 ++ egs/ljspeech/TTS/prepare.sh | 117 +++ egs/ljspeech/TTS/shared/parse_options.sh | 1 + egs/ljspeech/TTS/vits/README.md | 3 + egs/ljspeech/TTS/vits/duration_predictor.py | 194 ++++ egs/ljspeech/TTS/vits/export-onnx.py | 261 +++++ egs/ljspeech/TTS/vits/flow.py | 312 ++++++ egs/ljspeech/TTS/vits/generator.py | 531 ++++++++++ egs/ljspeech/TTS/vits/hifigan.py | 933 ++++++++++++++++++ egs/ljspeech/TTS/vits/infer.py | 233 +++++ egs/ljspeech/TTS/vits/loss.py | 336 +++++++ .../TTS/vits/monotonic_align/__init__.py | 81 ++ .../TTS/vits/monotonic_align/core.pyx | 51 + .../TTS/vits/monotonic_align/setup.py | 31 + egs/ljspeech/TTS/vits/posterior_encoder.py | 117 +++ egs/ljspeech/TTS/vits/residual_coupling.py | 229 +++++ egs/ljspeech/TTS/vits/test_onnx.py | 123 +++ egs/ljspeech/TTS/vits/text_encoder.py | 662 +++++++++++++ egs/ljspeech/TTS/vits/tokenizer.py | 106 ++ egs/ljspeech/TTS/vits/train.py | 893 +++++++++++++++++ egs/ljspeech/TTS/vits/transform.py | 218 ++++ egs/ljspeech/TTS/vits/tts_datamodule.py | 325 ++++++ egs/ljspeech/TTS/vits/utils.py | 265 +++++ 
egs/ljspeech/TTS/vits/vits.py | 610 ++++++++++++ egs/ljspeech/TTS/vits/wavenet.py | 349 +++++++ pyproject.toml | 1 + 34 files changed, 7517 insertions(+), 2 deletions(-) create mode 100644 docs/source/recipes/TTS/index.rst create mode 100644 docs/source/recipes/TTS/ljspeech/vits.rst create mode 100755 egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py create mode 100755 egs/ljspeech/TTS/local/display_manifest_statistics.py create mode 100755 egs/ljspeech/TTS/local/prepare_token_file.py create mode 100755 egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py create mode 100755 egs/ljspeech/TTS/local/validate_manifest.py create mode 100755 egs/ljspeech/TTS/prepare.sh create mode 120000 egs/ljspeech/TTS/shared/parse_options.sh create mode 100644 egs/ljspeech/TTS/vits/README.md create mode 100644 egs/ljspeech/TTS/vits/duration_predictor.py create mode 100755 egs/ljspeech/TTS/vits/export-onnx.py create mode 100644 egs/ljspeech/TTS/vits/flow.py create mode 100644 egs/ljspeech/TTS/vits/generator.py create mode 100644 egs/ljspeech/TTS/vits/hifigan.py create mode 100755 egs/ljspeech/TTS/vits/infer.py create mode 100644 egs/ljspeech/TTS/vits/loss.py create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/__init__.py create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/core.pyx create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/setup.py create mode 100644 egs/ljspeech/TTS/vits/posterior_encoder.py create mode 100644 egs/ljspeech/TTS/vits/residual_coupling.py create mode 100755 egs/ljspeech/TTS/vits/test_onnx.py create mode 100644 egs/ljspeech/TTS/vits/text_encoder.py create mode 100644 egs/ljspeech/TTS/vits/tokenizer.py create mode 100755 egs/ljspeech/TTS/vits/train.py create mode 100644 egs/ljspeech/TTS/vits/transform.py create mode 100644 egs/ljspeech/TTS/vits/tts_datamodule.py create mode 100644 egs/ljspeech/TTS/vits/utils.py create mode 100644 egs/ljspeech/TTS/vits/vits.py create mode 100644 egs/ljspeech/TTS/vits/wavenet.py diff --git a/.flake8 b/.flake8 index 410cb5482..cf276d0ba 100644 --- a/.flake8 +++ b/.flake8 @@ -15,7 +15,7 @@ per-file-ignores = egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203 egs/librispeech/ASR/zipformer/*.py: E501, E203 egs/librispeech/ASR/RESULTS.md: E999, - + egs/ljspeech/TTS/vits/*.py: E501, E203 # invalid escape sequence (cause by tex formular), W605 icefall/utils.py: E501, W605 diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst new file mode 100644 index 000000000..aa891c072 --- /dev/null +++ b/docs/source/recipes/TTS/index.rst @@ -0,0 +1,7 @@ +TTS +====== + +.. toctree:: + :maxdepth: 2 + + ljspeech/vits diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst new file mode 100644 index 000000000..385fd3c70 --- /dev/null +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -0,0 +1,113 @@ +VITS +=============== + +This tutorial shows you how to train an VITS model +with the `LJSpeech `_ dataset. + +.. note:: + + The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/ljspeech/TTS + $ ./prepare.sh + +To run stage 1 to stage 5, use + +.. code-block:: bash + + $ ./prepare.sh --stage 1 --stop_stage 5 + + +Build Monotonic Alignment Search +-------------------------------- + +.. code-block:: bash + + $ cd vits/monotonic_align + $ python setup.py build_ext --inplace + $ cd ../../ + + +Training +-------- + +.. 
code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0,1,2,3" + $ ./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 500 + +.. note:: + + You can adjust the hyper-parameters to control the size of the VITS model and + the training configurations. For more details, please run ``./vits/train.py --help``. + +.. note:: + + The training can take a long time (usually a couple of days). + +Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``. + + +Inference +--------- + +The inference part uses checkpoints saved by the training part, so you have to run the +training part first. It will save the ground-truth and generated wavs to the directory +``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``. + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0" + $ ./vits/infer.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 500 + +.. note:: + + For more details, please run ``./vits/infer.py --help``. + + +Export models +------------- + +Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: +``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. + +.. code-block:: bash + + $ ./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +You can test the exported ONNX model with: + +.. code-block:: bash + + $ ./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following link: + + - ``_ diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 7265e1cf6..8df61f0d0 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -2,7 +2,7 @@ Recipes ======= This page contains various recipes in ``icefall``. -Currently, only speech recognition recipes are provided. +Currently, we provide recipes for speech recognition, language model, and speech synthesis. We may add recipes for other tasks as well in the future. @@ -16,3 +16,4 @@ We may add recipes for other tasks as well in the future. Non-streaming-ASR/index Streaming-ASR/index RNN-LM/index + TTS/index diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py new file mode 100755 index 000000000..97c9008fc --- /dev/null +++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LJSpeech dataset. +It looks for manifests in the directory data/manifests. 
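+The features are linear spectrograms extracted with lhotse's Spectrogram
+extractor (1024-sample frames and a 256-sample hop at 22050 Hz).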
+ +The generated spectrogram features are saved in data/spectrogram. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_ljspeech(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(4, os.cpu_count()) + + sampling_rate = 22050 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet + ) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_ljspeech() diff --git a/egs/ljspeech/TTS/local/display_manifest_statistics.py b/egs/ljspeech/TTS/local/display_manifest_statistics.py new file mode 100755 index 000000000..93f0044f0 --- /dev/null +++ b/egs/ljspeech/TTS/local/display_manifest_statistics.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. 
+ +See the function `remove_short_and_long_utt()` in vits/train.py +for usage. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz" + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Cut statistics: + ╒═══════════════════════════╤══════════╕ + │ Cuts count: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Total duration (hh:mm:ss) │ 23:55:18 │ + ├───────────────────────────┼──────────┤ + │ mean │ 6.6 │ + ├───────────────────────────┼──────────┤ + │ std │ 2.2 │ + ├───────────────────────────┼──────────┤ + │ min │ 1.1 │ + ├───────────────────────────┼──────────┤ + │ 25% │ 5.0 │ + ├───────────────────────────┼──────────┤ + │ 50% │ 6.8 │ + ├───────────────────────────┼──────────┤ + │ 75% │ 8.4 │ + ├───────────────────────────┼──────────┤ + │ 99% │ 10.0 │ + ├───────────────────────────┼──────────┤ + │ 99.5% │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ 99.9% │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ max │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ Recordings available: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Features available: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Supervisions available: │ 13100 │ + ╘═══════════════════════════╧══════════╛ +""" diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py new file mode 100755 index 000000000..df976804a --- /dev/null +++ b/egs/ljspeech/TTS/local/prepare_token_file.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and generates the file that maps tokens to IDs. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict + +from lhotse import load_manifest + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-file", + type=Path, + default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"), + help="Path to the manifest file", + ) + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the tokens", + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_token2id(manifest_file: Path) -> Dict[str, int]: + """Return a dict that maps token to IDs.""" + extra_tokens = [ + "", # 0 for blank + "", # 1 for sos and eos symbols. 
+ "", # 2 for OOV + ] + all_tokens = set() + + cut_set = load_manifest(manifest_file) + + for cut in cut_set: + # Each cut only contain one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + for t in cut.tokens: + all_tokens.add(t) + + all_tokens = extra_tokens + list(all_tokens) + + token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} + return token2id + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + manifest_file = Path(args.manifest_file) + out_file = Path(args.tokens) + + token2id = get_token2id(manifest_file) + write_mapping(out_file, token2id) diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py new file mode 100755 index 000000000..fcd0137a0 --- /dev/null +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. +""" + +import logging +from pathlib import Path + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest + + +def prepare_tokens_ljspeech(): + output_dir = Path("data/spectrogram") + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + g2p = g2p_en.G2p() + + new_cuts = [] + for cut in cut_set: + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].normalized_text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + cut.tokens = g2p(text) + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_ljspeech() diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py new file mode 100755 index 000000000..68159ae03 --- /dev/null +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/ljspeech_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh new file mode 100755 index 000000000..8ee40896e --- /dev/null +++ b/egs/ljspeech/TTS/prepare.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=1 +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # The directory $dl_dir/LJSpeech-1.1 will contain: + # - wavs, which contains the audio files + # - metadata.csv, which provides the transcript text for each audio clip + + # If you have pre-downloaded it to /path/to/LJSpeech-1.1, you can create a symlink + # + # ln -sfv /path/to/LJSpeech-1.1 $dl_dir/LJSpeech-1.1 + # + if [ ! -d $dl_dir/LJSpeech-1.1 ]; then + lhotse download ljspeech $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LJSpeech manifest" + # We assume that you have downloaded the LJSpeech corpus + # to $dl_dir/LJSpeech + mkdir -p data/manifests + if [ ! -e data/manifests/.ljspeech.done ]; then + lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests + touch data/manifests/.ljspeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute spectrogram for LJSpeech" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.ljspeech.done ]; then + ./local/compute_spectrogram_ljspeech.py + touch data/spectrogram/.ljspeech.done + fi + + if [ ! 
-e data/spectrogram/.ljspeech-validated.done ]; then + log "Validating data/spectrogram for LJSpeech" + python3 ./local/validate_manifest.py \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare phoneme tokens for LJSpeech" + if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then + ./local/prepare_tokens_ljspeech.py + mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split the LJSpeech cuts into train, valid and test sets" + if [ ! -e data/spectrogram/.ljspeech_split.done ]; then + lhotse subset --last 600 \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_test.jsonl.gz + + rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_train.jsonl.gz + touch data/spectrogram/.ljspeech_split.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate token file" + # We assume you have installed g2p_en and espnet_tts_frontend. + # If not, please install them with: + # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py \ + --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \ + --tokens data/tokens.txt + fi +fi + + diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh new file mode 120000 index 000000000..e4665e7de --- /dev/null +++ b/egs/ljspeech/TTS/shared/parse_options.sh @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md new file mode 100644 index 000000000..1141326b9 --- /dev/null +++ b/egs/ljspeech/TTS/vits/README.md @@ -0,0 +1,3 @@ +See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials. + +Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29. diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py new file mode 100644 index 000000000..c29a28479 --- /dev/null +++ b/egs/ljspeech/TTS/vits/duration_predictor.py @@ -0,0 +1,194 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Stochastic duration predictor modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. 
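+
+The predictor models token durations with normalizing flows, so that durations
+are sampled at inference time rather than regressed deterministically.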
+ +""" + +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +from flow import ( + ConvFlow, + DilatedDepthSeparableConv, + ElementwiseAffineFlow, + FlipFlow, + LogFlow, +) + + +class StochasticDurationPredictor(torch.nn.Module): + """Stochastic duration predictor module. + + This is a module of stochastic duration predictor described in `Conditional + Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + + """ + + def __init__( + self, + channels: int = 192, + kernel_size: int = 3, + dropout_rate: float = 0.5, + flows: int = 4, + dds_conv_layers: int = 3, + global_channels: int = -1, + ): + """Initialize StochasticDurationPredictor module. + + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + dropout_rate (float): Dropout rate. + flows (int): Number of flows. + dds_conv_layers (int): Number of conv layers in DDS conv. + global_channels (int): Number of global conditioning channels. + + """ + super().__init__() + + self.pre = torch.nn.Conv1d(channels, channels, 1) + self.dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, + ) + self.proj = torch.nn.Conv1d(channels, channels, 1) + + self.log_flow = LogFlow() + self.flows = torch.nn.ModuleList() + self.flows += [ElementwiseAffineFlow(2)] + for i in range(flows): + self.flows += [ + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, + ) + ] + self.flows += [FlipFlow()] + + self.post_pre = torch.nn.Conv1d(1, channels, 1) + self.post_dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, + ) + self.post_proj = torch.nn.Conv1d(channels, channels, 1) + self.post_flows = torch.nn.ModuleList() + self.post_flows += [ElementwiseAffineFlow(2)] + for i in range(flows): + self.post_flows += [ + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, + ) + ] + self.post_flows += [FlipFlow()] + + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + w: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + noise_scale: float = 1.0, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T_text). + x_mask (Tensor): Mask tensor (B, 1, T_text). + w (Optional[Tensor]): Duration tensor (B, 1, T_text). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1) + inverse (bool): Whether to inverse the flow. + noise_scale (float): Noise scale value. + + Returns: + Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,). + If inverse, log-duration tensor (B, 1, T_text). + + """ + x = x.detach() # stop gradient + x = self.pre(x) + if g is not None: + x = x + self.global_conv(g.detach()) # stop gradient + x = self.dds(x, x_mask) + x = self.proj(x) * x_mask + + if not inverse: + assert w is not None, "w must be provided." 
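+ # Training path: variationally dequantize the integer durations w.
+ # The posterior flows, conditioned on (x + h_w), produce u in (0, 1) and a
+ # residual z1; (w - u) is log-transformed and passed through the main flows
+ # together with z1, and the returned value nll + logq is a variational upper
+ # bound on the duration negative log-likelihood (as in the VITS paper).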
+ h_w = self.post_pre(w) + h_w = self.post_dds(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn( + w.size(0), + 2, + w.size(2), + ).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + logdet_tot_q = 0.0 + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in self.flows: + z, logdet = flow(z, x_mask, g=x, inverse=inverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # (B,) + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn( + x.size(0), + 2, + x.size(2), + ).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, inverse=inverse) + z0, z1 = z.split(1, 1) + logw = z0 + return logw diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py new file mode 100755 index 000000000..154de4bf4 --- /dev/null +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate two files inside vits/exp: + - vits-epoch-1000.onnx + - vits-epoch-1000.int8.onnx (quantizated model) + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
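+
+ Note:
+   Besides tokens, the exported graph also takes tokens_lens, noise_scale,
+   noise_scale_dur and alpha as inputs; see the input_names passed to
+   torch.onnx.export below.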
+ """ + tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + }, + ) + + meta_data = { + "model_type": "VITS", + "version": "1", + "model_author": "k2-fsa", + "comment": "VITS generator", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model = model.generator + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py new file mode 100644 index 000000000..206bd5e3e --- /dev/null +++ b/egs/ljspeech/TTS/vits/flow.py @@ -0,0 +1,312 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Basic Flow modules used in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import math +from typing import Optional, Tuple, Union + +import torch + +from transform import piecewise_rational_quadratic_transform + + +class FlipFlow(torch.nn.Module): + """Flip flow module.""" + + def forward( + self, x: torch.Tensor, *args, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Flipped tensor (B, channels, T). 
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + x = torch.flip(x, [1]) + if not inverse: + logdet = x.new_zeros(x.size(0)) + return x, logdet + else: + return x + + +class LogFlow(torch.nn.Module): + """Log flow module.""" + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + inverse: bool = False, + eps: float = 1e-5, + **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + inverse (bool): Whether to inverse the flow. + eps (float): Epsilon for log. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + if not inverse: + y = torch.log(torch.clamp_min(x, eps)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class ElementwiseAffineFlow(torch.nn.Module): + """Elementwise affine flow module.""" + + def __init__(self, channels: int): + """Initialize ElementwiseAffineFlow module. + + Args: + channels (int): Number of channels. + + """ + super().__init__() + self.channels = channels + self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1))) + self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1))) + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_lengths (Tensor): Length tensor (B,). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + if not inverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class Transpose(torch.nn.Module): + """Transpose module for torch.nn.Sequential().""" + + def __init__(self, dim1: int, dim2: int): + """Initialize Transpose module.""" + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Transpose.""" + return x.transpose(self.dim1, self.dim2) + + +class DilatedDepthSeparableConv(torch.nn.Module): + """Dilated depth-separable conv module.""" + + def __init__( + self, + channels: int, + kernel_size: int, + layers: int, + dropout_rate: float = 0.0, + eps: float = 1e-5, + ): + """Initialize DilatedDepthSeparableConv module. + + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + dropout_rate (float): Dropout rate. + eps (float): Epsilon for layer norm. 
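+
+ Note:
+   Each layer applies a depthwise conv whose dilation is kernel_size**i,
+   followed by a pointwise 1x1 conv, each with LayerNorm and GELU, so the
+   receptive field grows with depth.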
+ + """ + super().__init__() + + self.convs = torch.nn.ModuleList() + for i in range(layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ), + Transpose(1, 2), + torch.nn.LayerNorm( + channels, + eps=eps, + elementwise_affine=True, + ), + Transpose(1, 2), + torch.nn.GELU(), + torch.nn.Conv1d( + channels, + channels, + 1, + ), + Transpose(1, 2), + torch.nn.LayerNorm( + channels, + eps=eps, + elementwise_affine=True, + ), + Transpose(1, 2), + torch.nn.GELU(), + torch.nn.Dropout(dropout_rate), + ) + ] + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, channels, T). + + """ + if g is not None: + x = x + g + for f in self.convs: + y = f(x * x_mask) + x = x + y + return x * x_mask + + +class ConvFlow(torch.nn.Module): + """Convolutional flow module.""" + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + layers: int, + bins: int = 10, + tail_bound: float = 5.0, + ): + """Initialize ConvFlow module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + bins (int): Number of bins. + tail_bound (float): Tail bound value. + + """ + super().__init__() + self.half_channels = in_channels // 2 + self.hidden_channels = hidden_channels + self.bins = bins + self.tail_bound = tail_bound + + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.dds_conv = DilatedDepthSeparableConv( + hidden_channels, + kernel_size, + layers, + dropout_rate=0.0, + ) + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * (bins * 3 - 1), + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
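+
+ Note:
+   Half of the channels (xa) condition a small conv stack that predicts the
+   spline parameters (widths, heights, derivatives); the other half (xb) is
+   transformed with a piecewise rational-quadratic spline, as in neural
+   spline flows.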
+ + """ + xa, xb = x.split(x.size(1) // 2, 1) + h = self.input_conv(xa) + h = self.dds_conv(h, x_mask, g=g) + h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T) + + b, c, t = xa.shape + # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1) + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) + + # TODO(kan-bayashi): Understand this calculation + denom = math.sqrt(self.hidden_channels) + unnorm_widths = h[..., : self.bins] / denom + unnorm_heights = h[..., self.bins : 2 * self.bins] / denom + unnorm_derivatives = h[..., 2 * self.bins :] + xb, logdet_abs = piecewise_rational_quadratic_transform( + xb, + unnorm_widths, + unnorm_heights, + unnorm_derivatives, + inverse=inverse, + tails="linear", + tail_bound=self.tail_bound, + ) + x = torch.cat([xa, xb], 1) * x_mask + logdet = torch.sum(logdet_abs * x_mask, [1, 2]) + if not inverse: + return x, logdet + else: + return x diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py new file mode 100644 index 000000000..efb0e254c --- /dev/null +++ b/egs/ljspeech/TTS/vits/generator.py @@ -0,0 +1,531 @@ +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Generator module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + + +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from icefall.utils import make_pad_mask + +from duration_predictor import StochasticDurationPredictor +from hifigan import HiFiGANGenerator +from posterior_encoder import PosteriorEncoder +from residual_coupling import ResidualAffineCouplingBlock +from text_encoder import TextEncoder +from utils import get_random_segments + + +class VITSGenerator(torch.nn.Module): + """Generator module in VITS, `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. 
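+
+ It combines a text encoder, a posterior encoder, a flow-based prior,
+ a stochastic duration predictor and a HiFi-GAN decoder, trained jointly
+ with monotonic alignment search.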
+ """ + + def __init__( + self, + vocabs: int, + aux_channels: int = 513, + hidden_channels: int = 192, + spks: Optional[int] = None, + langs: Optional[int] = None, + spk_embed_dim: Optional[int] = None, + global_channels: int = -1, + segment_size: int = 32, + text_encoder_attention_heads: int = 2, + text_encoder_ffn_expand: int = 4, + text_encoder_cnn_module_kernel: int = 5, + text_encoder_blocks: int = 6, + text_encoder_dropout_rate: float = 0.1, + decoder_kernel_size: int = 7, + decoder_channels: int = 512, + decoder_upsample_scales: List[int] = [8, 8, 2, 2], + decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + decoder_resblock_kernel_sizes: List[int] = [3, 7, 11], + decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_weight_norm_in_decoder: bool = True, + posterior_encoder_kernel_size: int = 5, + posterior_encoder_layers: int = 16, + posterior_encoder_stacks: int = 1, + posterior_encoder_base_dilation: int = 1, + posterior_encoder_dropout_rate: float = 0.0, + use_weight_norm_in_posterior_encoder: bool = True, + flow_flows: int = 4, + flow_kernel_size: int = 5, + flow_base_dilation: int = 1, + flow_layers: int = 4, + flow_dropout_rate: float = 0.0, + use_weight_norm_in_flow: bool = True, + use_only_mean_in_flow: bool = True, + stochastic_duration_predictor_kernel_size: int = 3, + stochastic_duration_predictor_dropout_rate: float = 0.5, + stochastic_duration_predictor_flows: int = 4, + stochastic_duration_predictor_dds_conv_layers: int = 3, + ): + """Initialize VITS generator module. + + Args: + vocabs (int): Input vocabulary size. + aux_channels (int): Number of acoustic feature channels. + hidden_channels (int): Number of hidden channels. + spks (Optional[int]): Number of speakers. If set to > 1, assume that the + sids will be provided as the input and use sid embedding layer. + langs (Optional[int]): Number of languages. If set to > 1, assume that the + lids will be provided as the input and use sid embedding layer. + spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, + assume that spembs will be provided as the input. + global_channels (int): Number of global conditioning channels. + segment_size (int): Segment size for decoder. + text_encoder_attention_heads (int): Number of heads in conformer block + of text encoder. + text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block + of text encoder. + text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder. + text_encoder_blocks (int): Number of conformer blocks in text encoder. + text_encoder_dropout_rate (float): Dropout rate in conformer block of + text encoder. + decoder_kernel_size (int): Decoder kernel size. + decoder_channels (int): Number of decoder initial channels. + decoder_upsample_scales (List[int]): List of upsampling scales in decoder. + decoder_upsample_kernel_sizes (List[int]): List of kernel size for + upsampling layers in decoder. + decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks + in decoder. + decoder_resblock_dilations (List[List[int]]): List of list of dilations for + resblocks in decoder. + use_weight_norm_in_decoder (bool): Whether to apply weight normalization in + decoder. + posterior_encoder_kernel_size (int): Posterior encoder kernel size. + posterior_encoder_layers (int): Number of layers of posterior encoder. + posterior_encoder_stacks (int): Number of stacks of posterior encoder. + posterior_encoder_base_dilation (int): Base dilation of posterior encoder. 
+ posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder. + use_weight_norm_in_posterior_encoder (bool): Whether to apply weight + normalization in posterior encoder. + flow_flows (int): Number of flows in flow. + flow_kernel_size (int): Kernel size in flow. + flow_base_dilation (int): Base dilation in flow. + flow_layers (int): Number of layers in flow. + flow_dropout_rate (float): Dropout rate in flow + use_weight_norm_in_flow (bool): Whether to apply weight normalization in + flow. + use_only_mean_in_flow (bool): Whether to use only mean in flow. + stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic + duration predictor. + stochastic_duration_predictor_dropout_rate (float): Dropout rate in + stochastic duration predictor. + stochastic_duration_predictor_flows (int): Number of flows in stochastic + duration predictor. + stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv + layers in stochastic duration predictor. + + """ + super().__init__() + self.segment_size = segment_size + self.text_encoder = TextEncoder( + vocabs=vocabs, + d_model=hidden_channels, + num_heads=text_encoder_attention_heads, + dim_feedforward=hidden_channels * text_encoder_ffn_expand, + cnn_module_kernel=text_encoder_cnn_module_kernel, + num_layers=text_encoder_blocks, + dropout=text_encoder_dropout_rate, + ) + self.decoder = HiFiGANGenerator( + in_channels=hidden_channels, + out_channels=1, + channels=decoder_channels, + global_channels=global_channels, + kernel_size=decoder_kernel_size, + upsample_scales=decoder_upsample_scales, + upsample_kernel_sizes=decoder_upsample_kernel_sizes, + resblock_kernel_sizes=decoder_resblock_kernel_sizes, + resblock_dilations=decoder_resblock_dilations, + use_weight_norm=use_weight_norm_in_decoder, + ) + self.posterior_encoder = PosteriorEncoder( + in_channels=aux_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + kernel_size=posterior_encoder_kernel_size, + layers=posterior_encoder_layers, + stacks=posterior_encoder_stacks, + base_dilation=posterior_encoder_base_dilation, + global_channels=global_channels, + dropout_rate=posterior_encoder_dropout_rate, + use_weight_norm=use_weight_norm_in_posterior_encoder, + ) + self.flow = ResidualAffineCouplingBlock( + in_channels=hidden_channels, + hidden_channels=hidden_channels, + flows=flow_flows, + kernel_size=flow_kernel_size, + base_dilation=flow_base_dilation, + layers=flow_layers, + global_channels=global_channels, + dropout_rate=flow_dropout_rate, + use_weight_norm=use_weight_norm_in_flow, + use_only_mean=use_only_mean_in_flow, + ) + # TODO(kan-bayashi): Add deterministic version as an option + self.duration_predictor = StochasticDurationPredictor( + channels=hidden_channels, + kernel_size=stochastic_duration_predictor_kernel_size, + dropout_rate=stochastic_duration_predictor_dropout_rate, + flows=stochastic_duration_predictor_flows, + dds_conv_layers=stochastic_duration_predictor_dds_conv_layers, + global_channels=global_channels, + ) + + self.upsample_factor = int(np.prod(decoder_upsample_scales)) + self.spks = None + if spks is not None and spks > 1: + assert global_channels > 0 + self.spks = spks + self.global_emb = torch.nn.Embedding(spks, global_channels) + self.spk_embed_dim = None + if spk_embed_dim is not None and spk_embed_dim > 0: + assert global_channels > 0 + self.spk_embed_dim = spk_embed_dim + self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels) + self.langs = None + if langs is not None and langs > 1: + assert 
global_channels > 0 + self.langs = langs + self.lang_emb = torch.nn.Embedding(langs, global_channels) + + # delayed import + from monotonic_align import maximum_path + + self.maximum_path = maximum_path + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + ]: + """Calculate forward propagation. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). + Tensor: Duration negative log-likelihood (NLL) tensor (B,). + Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text). + Tensor: Segments start index tensor (B,). + Tensor: Text mask tensor (B, 1, T_text). + Tensor: Feature mask tensor (B, 1, T_feats). + tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + - Tensor: Posterior encoder hidden representation (B, H, T_feats). + - Tensor: Flow hidden representation (B, H, T_feats). + - Tensor: Expanded text encoder projected mean (B, H, T_feats). + - Tensor: Expanded text encoder projected scale (B, H, T_feats). + - Tensor: Posterior encoder projected mean (B, H, T_feats). + - Tensor: Posterior encoder projected scale (B, H, T_feats). 
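+
+ Note:
+   The last tuple is returned so that the training loss can compute the KL
+   divergence between the flow-mapped posterior (z_p, logs_q) and the
+   expanded text prior (m_p, logs_p).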
+ + """ + # forward text encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + + # calculate global conditioning + g = None + if self.spks is not None: + # speaker one-hot vector embedding: (B, global_channels, 1) + g = self.global_emb(sids.view(-1)).unsqueeze(-1) + if self.spk_embed_dim is not None: + # pretreined speaker embedding, e.g., X-vector (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # language one-hot vector embedding: (B, global_channels, 1) + g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) + + # forward flow + z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) + + # monotonic alignment search + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) + # (B, 1, T_text) + neg_x_ent_1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r), + ) + # (B, 1, T_text) + neg_x_ent_4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, + ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = ( + self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), + ) + .unsqueeze(1) + .detach() + ) + + # forward duration predictor + w = attn.sum(2) # (B, 1, T_text) + dur_nll = self.duration_predictor(x, x_mask, w=w, g=g) + dur_nll = dur_nll / torch.sum(x_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + # get random segments + z_segments, z_start_idxs = get_random_segments( + z, + feats_lengths, + self.segment_size, + ) + + # forward decoder with random segments + wav = self.decoder(z_segments, g=g) + + return ( + wav, + dur_nll, + attn, + z_start_idxs, + x_mask, + y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + ) + + def inference( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: Optional[torch.Tensor] = None, + feats_lengths: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + dur: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (B, T_text,). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats,). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). 
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided, + skip the prediction of durations (i.e., teacher forcing). + noise_scale (float): Noise scale parameter for flow. + noise_scale_dur (float): Noise scale parameter for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length of acoustic feature sequence. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + Tensor: Generated waveform tensor (B, T_wav). + Tensor: Monotonic attention weight tensor (B, T_feats, T_text). + Tensor: Duration tensor (B, T_text). + + """ + # encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + x_mask = x_mask.to(x.dtype) + g = None + if self.spks is not None: + # (B, global_channels, 1) + g = self.global_emb(sids.view(-1)).unsqueeze(-1) + if self.spk_embed_dim is not None: + # (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # (B, global_channels, 1) + g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + if use_teacher_forcing: + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) + + # forward flow + z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) + + # monotonic alignment search + s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) + # (B, 1, T_text) + neg_x_ent_1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r), + ) + # (B, 1, T_text) + neg_x_ent_4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, + ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), + ).unsqueeze(1) + dur = attn.sum(2) # (B, 1, T_text) + + # forward decoder with random segments + wav = self.decoder(z * y_mask, g=g) + else: + # duration + if dur is None: + logw = self.duration_predictor( + x, + x_mask, + g=g, + inverse=True, + noise_scale=noise_scale_dur, + ) + w = torch.exp(logw) * x_mask * alpha + dur = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long() + y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device) + y_mask = y_mask.to(x.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = self._generate_path(dur, attn_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = torch.matmul( + attn.squeeze(1), + m_p.transpose(1, 2), + ).transpose(1, 2) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = torch.matmul( + attn.squeeze(1), + logs_p.transpose(1, 2), + ).transpose(1, 2) + + # decoder + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, inverse=True) + 
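+ # The inverted flow maps the sampled prior latent z_p into the posterior
+ # space; the HiFi-GAN decoder below then converts (z * y_mask), optionally
+ # truncated to max_len frames, into a waveform.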
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g) + + return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1) + + def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate path a.k.a. monotonic attention. + + Args: + dur (Tensor): Duration tensor (B, 1, T_text). + mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text). + + Returns: + Tensor: Path tensor (B, 1, T_feats, T_text). + + """ + b, _, t_y, t_x = mask.shape + cum_dur = torch.cumsum(dur, -1) + cum_dur_flat = cum_dur.view(b * t_x) + path = torch.arange(t_y, dtype=dur.dtype, device=dur.device) + path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) + # path = path.view(b, t_x, t_y).to(dtype=mask.dtype) + path = path.view(b, t_x, t_y).to(dtype=torch.float) + # path will be like (t_x = 3, t_y = 5): + # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], + # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], + # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] + path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] + # path = path.to(dtype=mask.dtype) + return path.unsqueeze(1).transpose(2, 3) * mask diff --git a/egs/ljspeech/TTS/vits/hifigan.py b/egs/ljspeech/TTS/vits/hifigan.py new file mode 100644 index 000000000..589ac30f6 --- /dev/null +++ b/egs/ljspeech/TTS/vits/hifigan.py @@ -0,0 +1,933 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""HiFi-GAN Modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +import copy +import logging +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F + + +class HiFiGANGenerator(torch.nn.Module): + """HiFiGAN generator module.""" + + def __init__( + self, + in_channels: int = 80, + out_channels: int = 1, + channels: int = 512, + global_channels: int = -1, + kernel_size: int = 7, + upsample_scales: List[int] = [8, 8, 2, 2], + upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_additional_convs: bool = True, + bias: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + ): + """Initialize HiFiGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + channels (int): Number of hidden representation channels. + global_channels (int): Number of global conditioning channels. + kernel_size (int): Kernel size of initial and final conv layer. + upsample_scales (List[int]): List of upsampling scales. + upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers. + resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks. + resblock_dilations (List[List[int]]): List of list of dilations for residual + blocks. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. 
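+
+ Note:
+   The overall upsampling factor is the product of upsample_scales
+   (8 * 8 * 2 * 2 = 256 with the defaults), which matches the 256-sample
+   frame shift used for the spectrograms in this recipe.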
+ + """ + super().__init__() + + # check hyperparameters are valid + assert kernel_size % 2 == 1, "Kernel size must be odd number." + assert len(upsample_scales) == len(upsample_kernel_sizes) + assert len(resblock_dilations) == len(resblock_kernel_sizes) + + # define modules + self.upsample_factor = int(np.prod(upsample_scales) * out_channels) + self.num_upsamples = len(upsample_kernel_sizes) + self.num_blocks = len(resblock_kernel_sizes) + self.input_conv = torch.nn.Conv1d( + in_channels, + channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + self.upsamples = torch.nn.ModuleList() + self.blocks = torch.nn.ModuleList() + for i in range(len(upsample_kernel_sizes)): + assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ConvTranspose1d( + channels // (2**i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + ), + ) + ] + for j in range(len(resblock_kernel_sizes)): + self.blocks += [ + ResidualBlock( + kernel_size=resblock_kernel_sizes[j], + channels=channels // (2 ** (i + 1)), + dilations=resblock_dilations[j], + bias=bias, + use_additional_convs=use_additional_convs, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + ) + ] + self.output_conv = torch.nn.Sequential( + # NOTE(kan-bayashi): follow official implementation but why + # using different slope parameter here? (0.1 vs. 0.01) + torch.nn.LeakyReLU(), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.Tanh(), + ) + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + c = self.input_conv(c) + if g is not None: + c = c + self.global_conv(g) + for i in range(self.num_upsamples): + c = self.upsamples[i](c) + cs = 0.0 # initialize + for j in range(self.num_blocks): + cs += self.blocks[i * self.num_blocks + j](c) + c = cs / self.num_blocks + c = self.output_conv(c) + + return c + + def reset_parameters(self): + """Reset parameters. + + This initialization follows the official implementation manner. 
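To make the shape bookkeeping of HiFiGANGenerator above concrete, the small sketch below (just arithmetic over the default hyperparameters from __init__, nothing new) prints the per-stage channel counts and the overall upsampling factor: each ConvTranspose1d halves the channel count and stretches time by its scale, so 80-dim mel frames at a 256-sample hop come out as raw waveform.

import numpy as np

channels = 512
upsample_scales = [8, 8, 2, 2]

print("upsample_factor:", int(np.prod(upsample_scales)))   # 256 waveform samples per input frame
for i, s in enumerate(upsample_scales):
    print(f"stage {i}: {channels // (2 ** i)} -> {channels // (2 ** (i + 1))} channels, x{s} in time")
# stage 0: 512 -> 256 channels, x8 in time
# stage 1: 256 -> 128 channels, x8 in time
# stage 2: 128 -> 64 channels, x2 in time
# stage 3: 64 -> 32 channels, x2 in time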
+ https://github.com/jik876/hifi-gan/blob/master/models.py + + """ + + def _reset_parameters(m: torch.nn.Module): + if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): + m.weight.data.normal_(0.0, 0.01) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m: torch.nn.Module): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def inference( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Perform inference. + + Args: + c (torch.Tensor): Input tensor (T, in_channels). + g (Optional[Tensor]): Global conditioning tensor (global_channels, 1). + + Returns: + Tensor: Output tensor (T ** upsample_factor, out_channels). + + """ + if g is not None: + g = g.unsqueeze(0) + c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g) + return c.squeeze(0).transpose(1, 0) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in HiFiGAN.""" + + def __init__( + self, + kernel_size: int = 3, + channels: int = 512, + dilations: List[int] = [1, 3, 5], + bias: bool = True, + use_additional_convs: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + channels (int): Number of channels for convolution layer. + dilations (List[int]): List of dilation factors. + use_additional_convs (bool): Whether to use additional convolution layers. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + + """ + super().__init__() + self.use_additional_convs = use_additional_convs + self.convs1 = torch.nn.ModuleList() + if use_additional_convs: + self.convs2 = torch.nn.ModuleList() + assert kernel_size % 2 == 1, "Kernel size must be odd number." + for dilation in dilations: + self.convs1 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + bias=bias, + padding=(kernel_size - 1) // 2 * dilation, + ), + ) + ] + if use_additional_convs: + self.convs2 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + bias=bias, + padding=(kernel_size - 1) // 2, + ), + ) + ] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, channels, T). 
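One detail of ResidualBlock worth spelling out: the padding (kernel_size - 1) // 2 * dilation keeps the time length unchanged for every dilation in the list, which is what makes the residual addition xt + x in forward valid. A minimal check with kernel size 3 and the default dilations:

import torch

x = torch.randn(1, 16, 100)              # (B, channels, T)
kernel_size = 3
for dilation in [1, 3, 5]:
    conv = torch.nn.Conv1d(
        16, 16, kernel_size,
        dilation=dilation,
        padding=(kernel_size - 1) // 2 * dilation,
    )
    assert conv(x).shape == x.shape      # length preserved, so the residual add works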
+ + """ + for idx in range(len(self.convs1)): + xt = self.convs1[idx](x) + if self.use_additional_convs: + xt = self.convs2[idx](xt) + x = xt + x + return x + + +class HiFiGANPeriodDiscriminator(torch.nn.Module): + """HiFiGAN period discriminator module.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + period: int = 3, + kernel_sizes: List[int] = [5, 3], + channels: int = 32, + downsample_scales: List[int] = [3, 3, 3, 3, 1], + max_downsample_channels: int = 1024, + bias: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + use_spectral_norm: bool = False, + ): + """Initialize HiFiGANPeriodDiscriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + period (int): Period. + kernel_sizes (list): Kernel sizes of initial conv layers and the final conv + layer. + channels (int): Number of initial channels. + downsample_scales (List[int]): List of downsampling scales. + max_downsample_channels (int): Number of maximum downsampling channels. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. + If set to true, it will be applied to all of the conv layers. + + """ + super().__init__() + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number." + assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number." + + self.period = period + self.convs = torch.nn.ModuleList() + in_chs = in_channels + out_chs = channels + for downsample_scale in downsample_scales: + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv2d( + in_chs, + out_chs, + (kernel_sizes[0], 1), + (downsample_scale, 1), + padding=((kernel_sizes[0] - 1) // 2, 0), + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Use downsample_scale + 1? + out_chs = min(out_chs * 4, max_downsample_channels) + self.output_conv = torch.nn.Conv2d( + out_chs, + out_channels, + (kernel_sizes[1] - 1, 1), + 1, + padding=((kernel_sizes[1] - 1) // 2, 0), + ) + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + if use_spectral_norm: + self.apply_spectral_norm() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + + Returns: + list: List of each layer's tensors. 
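The first step of HiFiGANPeriodDiscriminator.forward (in the body that follows) folds the 1-D waveform into a 2-D grid whose second axis is the period, after reflect-padding the length to a multiple of that period; the 2-D convolutions then only compare samples that are a whole period apart. A standalone sketch of just that reshaping step:

import torch
import torch.nn.functional as F

period = 3
x = torch.randn(2, 1, 100)               # (B, C, T) raw waveform
b, c, t = x.shape
if t % period != 0:
    n_pad = period - (t % period)        # 100 -> pad by 2
    x = F.pad(x, (0, n_pad), "reflect")
    t += n_pad
x = x.view(b, c, t // period, period)
print(x.shape)                           # torch.Size([2, 1, 34, 3])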
+ + """ + # transform 1d to 2d -> (B, C, T/P, P) + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t += n_pad + x = x.view(b, c, t // self.period, self.period) + + # forward conv + outs = [] + for layer in self.convs: + x = layer(x) + outs += [x] + x = self.output_conv(x) + x = torch.flatten(x, 1, -1) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + +class HiFiGANMultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN multi-period discriminator module.""" + + def __init__( + self, + periods: List[int] = [2, 3, 5, 7, 11], + discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initialize HiFiGANMultiPeriodDiscriminator module. + + Args: + periods (List[int]): List of periods. + discriminator_params (Dict[str, Any]): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + for period in periods: + params = copy.deepcopy(discriminator_params) + params["period"] = period + self.discriminators += [HiFiGANPeriodDiscriminator(**params)] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each + layer output tensors. + + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + + return outs + + +class HiFiGANScaleDiscriminator(torch.nn.Module): + """HiFi-GAN scale discriminator module.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + kernel_sizes: List[int] = [15, 41, 5, 3], + channels: int = 128, + max_downsample_channels: int = 1024, + max_groups: int = 16, + bias: int = True, + downsample_scales: List[int] = [2, 2, 4, 4, 1], + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + use_spectral_norm: bool = False, + ): + """Initilize HiFiGAN scale discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (List[int]): List of four kernel sizes. The first will be used + for the first conv layer, and the second is for downsampling part, and + the remaining two are for the last two output layers. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling + layers. 
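For readers tracing shapes into the loss modules later in this patch, note what HiFiGANMultiPeriodDiscriminator above returns: one entry per period, and each entry is the per-layer list from HiFiGANPeriodDiscriminator.forward (five downsampling stages plus the flattened final output). A hedged way to inspect that nesting, assuming hifigan.py from this patch is importable:

import torch
from hifigan import HiFiGANMultiPeriodDiscriminator

mpd = HiFiGANMultiPeriodDiscriminator()      # default periods [2, 3, 5, 7, 11]
outs = mpd(torch.randn(1, 1, 8192))
print(len(outs))                             # 5: one list per period
print(len(outs[0]))                          # 6: 5 conv stages + flattened output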
+ bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (List[int]): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. If set to true, it + will be applied to all of the conv layers. + + """ + super().__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 4 + for ks in kernel_sizes: + assert ks % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_channels, + channels, + # NOTE(kan-bayashi): Use always the same kernel size + kernel_sizes[0], + bias=bias, + padding=(kernel_sizes[0] - 1) // 2, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + out_chs = channels + # NOTE(kan-bayashi): Remove hard coding? + groups = 4 + for downsample_scale in downsample_scales: + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[1], + stride=downsample_scale, + padding=(kernel_sizes[1] - 1) // 2, + groups=groups, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Remove hard coding? + out_chs = min(in_chs * 2, max_downsample_channels) + # NOTE(kan-bayashi): Remove hard coding? + groups = min(groups * 4, max_groups) + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, + out_channels, + kernel_size=kernel_sizes[3], + stride=1, + padding=(kernel_sizes[3] - 1) // 2, + bias=bias, + ), + ] + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + self.use_weight_norm = use_weight_norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + self.use_spectral_norm = use_spectral_norm + if use_spectral_norm: + self.apply_spectral_norm() + + # backward compatibility + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[Tensor]: List of output tensors of each layer. 
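The channel and group schedule in HiFiGANScaleDiscriminator.__init__ above (channels doubling up to max_downsample_channels, groups growing fourfold up to max_groups) is easiest to sanity-check numerically. With the default arguments:

channels = 128
max_downsample_channels = 1024
max_groups = 16
downsample_scales = [2, 2, 4, 4, 1]

in_chs, out_chs, groups = channels, channels, 4
for scale in downsample_scales:
    print(f"downsample x{scale}: {in_chs} -> {out_chs} channels, groups={groups}")
    in_chs = out_chs
    out_chs = min(in_chs * 2, max_downsample_channels)
    groups = min(groups * 4, max_groups)
print(f"final conv pair: {in_chs} -> {min(in_chs * 2, max_downsample_channels)} channels")
# downsample x2: 128 -> 128 channels, groups=4
# downsample x2: 128 -> 256 channels, groups=16
# downsample x4: 256 -> 512 channels, groups=16
# downsample x4: 512 -> 1024 channels, groups=16
# downsample x1: 1024 -> 1024 channels, groups=16
# final conv pair: 1024 -> 1024 channels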
+ + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def remove_spectral_norm(self): + """Remove spectral normalization module from all of the layers.""" + + def _remove_spectral_norm(m): + try: + logging.debug(f"Spectral norm is removed from {m}.") + torch.nn.utils.remove_spectral_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_spectral_norm) + + def _load_state_dict_pre_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """Fix the compatibility of weight / spectral normalization issue. + + Some pretrained models are trained with configs that use weight / spectral + normalization, but actually, the norm is not applied. This causes the mismatch + of the parameters with configs. To solve this issue, when parameter mismatch + happens in loading pretrained model, we remove the norm from the current model. + + See also: + - https://github.com/espnet/espnet/pull/5240 + - https://github.com/espnet/espnet/pull/5249 + - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409 + + """ + current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)] + if self.use_weight_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems weight norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_weight_norm() + self.use_weight_norm = False + for k in current_module_keys: + if k.endswith("weight_g") or k.endswith("weight_v"): + del state_dict[k] + + if self.use_spectral_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems spectral norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. 
To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_spectral_norm() + self.use_spectral_norm = False + for k in current_module_keys: + if ( + k.endswith("weight_u") + or k.endswith("weight_v") + or k.endswith("weight_orig") + ): + del state_dict[k] + + +class HiFiGANMultiScaleDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale discriminator module.""" + + def __init__( + self, + scales: int = 3, + downsample_pooling: str = "AvgPool1d", + # follow the official implementation setting + downsample_pooling_params: Dict[str, Any] = { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm: bool = False, + ): + """Initilize HiFiGAN multi-scale discriminator module. + + Args: + scales (int): Number of multi-scales. + downsample_pooling (str): Pooling module name for downsampling of the + inputs. + downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling + module. + discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm + and the other discriminators use weight norm. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + + # add discriminators + for i in range(scales): + params = copy.deepcopy(discriminator_params) + if follow_official_norm: + if i == 0: + params["use_weight_norm"] = False + params["use_spectral_norm"] = True + else: + params["use_weight_norm"] = True + params["use_spectral_norm"] = False + self.discriminators += [HiFiGANScaleDiscriminator(**params)] + self.pooling = None + if scales > 1: + self.pooling = getattr(torch.nn, downsample_pooling)( + **downsample_pooling_params + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[torch.Tensor]]: List of list of each discriminator outputs, + which consists of eachlayer output tensors. 
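The _load_state_dict_pre_hook above relies on a PyTorch detail that is easy to miss: torch.nn.utils.weight_norm replaces a module's weight parameter with weight_g and weight_v, so a checkpoint trained without the norm only contains plain weight keys and would otherwise fail to load. A quick standalone demonstration of that renaming (the hook simply removes the norm from the current model when it detects the mismatch):

import torch

conv = torch.nn.Conv1d(4, 4, 3)
print(sorted(name for name, _ in conv.named_parameters()))
# ['bias', 'weight']

torch.nn.utils.weight_norm(conv)
print(sorted(name for name, _ in conv.named_parameters()))
# ['bias', 'weight_g', 'weight_v']

torch.nn.utils.remove_weight_norm(conv)
print(sorted(name for name, _ in conv.named_parameters()))
# ['bias', 'weight']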
+ + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + if self.pooling is not None: + x = self.pooling(x) + + return outs + + +class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale + multi-period discriminator module.""" + + def __init__( + self, + # Multi-scale discriminator related + scales: int = 3, + scale_downsample_pooling: str = "AvgPool1d", + scale_downsample_pooling_params: Dict[str, Any] = { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + scale_discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm: bool = True, + # Multi-period discriminator related + periods: List[int] = [2, 3, 5, 7, 11], + period_discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initilize HiFiGAN multi-scale + multi-period discriminator module. + + Args: + scales (int): Number of multi-scales. + scale_downsample_pooling (str): Pooling module name for downsampling of the + inputs. + scale_downsample_pooling_params (dict): Parameters for the above pooling + module. + scale_discriminator_params (dict): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm and + the other discriminators use weight norm. + periods (list): List of periods. + period_discriminator_params (dict): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.msd = HiFiGANMultiScaleDiscriminator( + scales=scales, + downsample_pooling=scale_downsample_pooling, + downsample_pooling_params=scale_downsample_pooling_params, + discriminator_params=scale_discriminator_params, + follow_official_norm=follow_official_norm, + ) + self.mpd = HiFiGANMultiPeriodDiscriminator( + periods=periods, + discriminator_params=period_discriminator_params, + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[Tensor]]: List of list of each discriminator outputs, + which consists of each layer output tensors. Multi scale and + multi period ones are concatenated. + + """ + msd_outs = self.msd(x) + mpd_outs = self.mpd(x) + return msd_outs + mpd_outs diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py new file mode 100755 index 000000000..91a35e360 --- /dev/null +++ b/egs/ljspeech/TTS/vits/infer.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List + +import k2 +import torch +import torch.nn as nn +import torchaudio + +from train import get_model, get_params +from tokenizer import Tokenizer + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger +from tts_datamodule import LJSpeechTtsDataModule + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + tokenizer: Tokenizer, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + # Background worker save audios to disk. + def _save_worker( + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), + audio[i:i + 1, :audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), + audio_pred[i:i + 1, :audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + + futures.append( + executor.submit( + _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to(device) + model.eval() + + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + # we need cut ids to display recognition results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + test_dl = ljspeech.test_dataloaders(test_cuts) + + infer_dataset( + dl=test_dl, + params=params, + model=model, + tokenizer=tokenizer, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py new file mode 100644 index 000000000..21aaad6e7 --- /dev/null +++ b/egs/ljspeech/TTS/vits/loss.py @@ -0,0 +1,336 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""HiFiGAN-related loss modules. 
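One line of infer_dataset above deserves a note: the predicted duration tensor counts feature frames per input token, so summing over tokens and multiplying by the hop size (params.frame_shift, in samples) yields the number of generated waveform samples per utterance. A toy re-computation, assuming a 256-sample hop:

import torch

frame_shift = 256                             # assumed hop size in samples
durations = torch.tensor([[2.0, 3.0, 1.0],
                          [4.0, 0.0, 0.0]])   # (B, T_text) frames per token
audio_lens_pred = (durations.sum(1) * frame_shift).to(dtype=torch.int64).tolist()
print(audio_lens_pred)                        # [1536, 1024]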
+ +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +from typing import List, Tuple, Union + +import torch +import torch.distributions as D +import torch.nn.functional as F + +from lhotse.features.kaldi import Wav2LogFilterBank + + +class GeneratorAdversarialLoss(torch.nn.Module): + """Generator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "mse", + ): + """Initialize GeneratorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.criterion = self._mse_loss + else: + self.criterion = self._hinge_loss + + def forward( + self, + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Calcualate generator adversarial loss. + + Args: + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs.. + + Returns: + Tensor: Generator adversarial loss value. + + """ + if isinstance(outputs, (tuple, list)): + adv_loss = 0.0 + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + adv_loss = self.criterion(outputs) + + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, x.new_ones(x.size())) + + def _hinge_loss(self, x): + return -x.mean() + + +class DiscriminatorAdversarialLoss(torch.nn.Module): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "mse", + ): + """Initialize DiscriminatorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + else: + self.fake_criterion = self._hinge_fake_loss + self.real_criterion = self._hinge_real_loss + + def forward( + self, + outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calcualate discriminator adversarial loss. + + Args: + outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from generator. + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from groundtruth. + + Returns: + Tensor: Discriminator real loss value. + Tensor: Discriminator fake loss value. 
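For reference, the two criteria selectable in GeneratorAdversarialLoss above reduce to simple formulas: with loss_type "mse" (LSGAN style) the generator pushes discriminator outputs toward 1 via mean((D(x_hat) - 1)^2), while "hinge" simply maximizes the discriminator score via -mean(D(x_hat)). A two-line check on a dummy discriminator output:

import torch
import torch.nn.functional as F

d_fake = torch.tensor([0.2, -0.5, 1.0])                        # D(x_hat) on generated audio
mse_gen = F.mse_loss(d_fake, torch.ones_like(d_fake))          # mean((D(x_hat) - 1)^2)
hinge_gen = -d_fake.mean()                                     # -E[D(x_hat)]
print(mse_gen.item(), hinge_gen.item())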
+ + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_ones(x.size())) + + def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_zeros(x.size())) + + def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) + + def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) + + +class FeatureMatchLoss(torch.nn.Module): + """Feature matching loss module.""" + + def __init__( + self, + average_by_layers: bool = True, + average_by_discriminators: bool = True, + include_final_outputs: bool = False, + ): + """Initialize FeatureMatchLoss module. + + Args: + average_by_layers (bool): Whether to average the loss by the number + of layers. + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + include_final_outputs (bool): Whether to include the final output of + each discriminator for loss calculation. + + """ + super().__init__() + self.average_by_layers = average_by_layers + self.average_by_discriminators = average_by_discriminators + self.include_final_outputs = include_final_outputs + + def forward( + self, + feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], + feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], + ) -> torch.Tensor: + """Calculate feature matching loss. + + Args: + feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from generator's outputs. + feats (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from groundtruth.. + + Returns: + Tensor: Feature matching loss value. 
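Mirroring the generator side, DiscriminatorAdversarialLoss above drives real outputs toward 1 and generated outputs toward 0 in the MSE case, while the hinge variant enforces margins at +1 for real and -1 for fake scores. A compact sketch of the four criteria on dummy scores:

import torch
import torch.nn.functional as F

real = torch.tensor([0.9, 1.2, 0.4])        # D(x) on ground truth
fake = torch.tensor([0.1, -0.3, 0.6])       # D(x_hat) on generated audio

mse_real = F.mse_loss(real, torch.ones_like(real))                        # want D(x) close to 1
mse_fake = F.mse_loss(fake, torch.zeros_like(fake))                       # want D(x_hat) close to 0
hinge_real = -torch.mean(torch.min(real - 1, torch.zeros_like(real)))     # penalize D(x) < 1
hinge_fake = -torch.mean(torch.min(-fake - 1, torch.zeros_like(fake)))    # penalize D(x_hat) > -1
print(mse_real.item(), mse_fake.item(), hinge_real.item(), hinge_fake.item())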
+ + """ + feat_match_loss = 0.0 + for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): + feat_match_loss_ = 0.0 + if not self.include_final_outputs: + feats_hat_ = feats_hat_[:-1] + feats_ = feats_[:-1] + for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): + feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + if self.average_by_layers: + feat_match_loss_ /= j + 1 + feat_match_loss += feat_match_loss_ + if self.average_by_discriminators: + feat_match_loss /= i + 1 + + return feat_match_loss + + +class MelSpectrogramLoss(torch.nn.Module): + """Mel-spectrogram loss.""" + + def __init__( + self, + sampling_rate: int = 22050, + frame_length: int = 1024, # in samples + frame_shift: int = 256, # in samples + n_mels: int = 80, + use_fft_mag: bool = True, + ): + super().__init__() + self.wav_to_mel = Wav2LogFilterBank( + sampling_rate=sampling_rate, + frame_length=frame_length / sampling_rate, # in second + frame_shift=frame_shift / sampling_rate, # in second + use_fft_mag=use_fft_mag, + num_filters=n_mels, + ) + + def forward( + self, + y_hat: torch.Tensor, + y: torch.Tensor, + return_mel: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + """Calculate Mel-spectrogram loss. + + Args: + y_hat (Tensor): Generated waveform tensor (B, 1, T). + y (Tensor): Groundtruth waveform tensor (B, 1, T). + spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor + (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth + waveform. + + Returns: + Tensor: Mel-spectrogram loss value. + + """ + mel_hat = self.wav_to_mel(y_hat.squeeze(1)) + mel = self.wav_to_mel(y.squeeze(1)) + mel_loss = F.l1_loss(mel_hat, mel) + + if return_mel: + return mel_loss, (mel_hat, mel) + + return mel_loss + + +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py + +"""VITS-related loss modules. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + + +class KLDivergenceLoss(torch.nn.Module): + """KL divergence loss.""" + + def forward( + self, + z_p: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + z_mask: torch.Tensor, + ) -> torch.Tensor: + """Calculate KL divergence loss. + + Args: + z_p (Tensor): Flow hidden representation (B, H, T_feats). + logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). + z_mask (Tensor): Mask tensor (B, 1, T_feats). + + Returns: + Tensor: KL divergence loss. + + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + loss = kl / torch.sum(z_mask) + + return loss + + +class KLDivergenceLossWithoutFlow(torch.nn.Module): + """KL divergence loss without flow.""" + + def forward( + self, + m_q: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + ) -> torch.Tensor: + """Calculate KL divergence loss without flow. + + Args: + m_q (Tensor): Posterior encoder projected mean (B, H, T_feats). + logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). 
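The closed-form expression in KLDivergenceLoss.forward above is a single-sample estimate of KL(q || p) between diagonal Gaussians: the term 0.5 * (z_p - m_p)^2 * exp(-2 * logs_p), evaluated at a sample z_p drawn from q, matches the usual (sigma_q^2 + (m_q - m_p)^2) / (2 * sigma_p^2) term in expectation. A small numerical check against torch.distributions, averaging over many samples:

import torch
import torch.distributions as D

torch.manual_seed(0)
m_q, logs_q = torch.tensor(0.3), torch.tensor(-0.2)
m_p, logs_p = torch.tensor(-0.5), torch.tensor(0.1)

# exact KL between N(m_q, exp(logs_q)) and N(m_p, exp(logs_p))
exact = D.kl_divergence(D.Normal(m_q, logs_q.exp()), D.Normal(m_p, logs_p.exp()))

# Monte Carlo estimate using the same per-element formula as KLDivergenceLoss
z = m_q + torch.randn(100_000) * logs_q.exp()
mc = (logs_p - logs_q - 0.5 + 0.5 * (z - m_p) ** 2 * torch.exp(-2.0 * logs_p)).mean()

print(exact.item(), mc.item())   # the two agree up to sampling noise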
+ """ + posterior_norm = D.Normal(m_q, torch.exp(logs_q)) + prior_norm = D.Normal(m_p, torch.exp(logs_p)) + loss = D.kl_divergence(posterior_norm, prior_norm).mean() + return loss diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py new file mode 100644 index 000000000..2b35654f5 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py @@ -0,0 +1,81 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py + +"""Maximum path calculation module. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import warnings + +import numpy as np +import torch +from numba import njit, prange + +try: + from .core import maximum_path_c + + is_cython_avalable = True +except ImportError: + is_cython_avalable = False + warnings.warn( + "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. " + "If you want to use the cython version, please build it as follows: " + "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`" + ) + + +def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + """Calculate maximum path. + + Args: + neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text). + attn_mask (Tensor): Attention mask (B, T_feats, T_text). + + Returns: + Tensor: Maximum path tensor (B, T_feats, T_text). + + """ + device, dtype = neg_x_ent.device, neg_x_ent.dtype + neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32) + path = np.zeros(neg_x_ent.shape, dtype=np.int32) + t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32) + t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32) + if is_cython_avalable: + maximum_path_c(path, neg_x_ent, t_t_max, t_s_max) + else: + maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max) + + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +@njit +def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf): + """Calculate a single maximum path with numba.""" + index = t_x - 1 + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@njit(parallel=True) +def maximum_path_numba(paths, values, t_ys, t_xs): + """Calculate batch maximum path with numba.""" + for i in prange(paths.shape[0]): + maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/core.pyx b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx new file mode 100644 index 000000000..c02c2d02e --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx @@ -0,0 +1,51 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx + +"""Maximum path calculation module with cython optimization. + +This code is copied from https://github.com/jaywalnut310/vits and modifed code format. 
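To see what the maximum-path search above actually produces, the sketch below runs maximum_path from monotonic_align/__init__.py (falling back to the numba version when the Cython extension is not built) on a toy score matrix: for every feature frame it selects exactly one text token such that the alignment is monotonic and non-skipping while the summed score is maximal. This assumes the vits directory from this patch is on PYTHONPATH.

import torch
from monotonic_align import maximum_path

neg_x_ent = torch.tensor([[[ 0.0, -1.0, -2.0],
                           [ 0.0,  1.0, -2.0],
                           [-1.0,  2.0,  0.0],
                           [-2.0,  0.0,  1.0],
                           [-3.0, -1.0,  2.0]]])     # (B=1, T_feats=5, T_text=3)
attn_mask = torch.ones(1, 5, 3)
path = maximum_path(neg_x_ent, attn_mask)
print(path[0])
# each row (frame) has exactly one 1, and the selected column index
# never decreases and never skips a token from frame to frame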
+ +""" + +cimport cython + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil: + cdef int b = paths.shape[0] + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py new file mode 100644 index 000000000..33d75e176 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/setup.py @@ -0,0 +1,31 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py +"""Setup cython code.""" + +from Cython.Build import cythonize +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext + + +class build_ext(_build_ext): + """Overwrite build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + + +exts = [ + Extension( + name="core", + sources=["core.pyx"], + ) +] +setup( + name="monotonic_align", + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, +) diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py new file mode 100644 index 000000000..6b8a5be52 --- /dev/null +++ b/egs/ljspeech/TTS/vits/posterior_encoder.py @@ -0,0 +1,117 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Posterior encoder module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple + +import torch + +from icefall.utils import make_pad_mask +from wavenet import WaveNet, Conv1d + + +class PosteriorEncoder(torch.nn.Module): + """Posterior encoder module in VITS. + + This is a module of posterior encoder described in `Conditional Variational + Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + """ + + def __init__( + self, + in_channels: int = 513, + out_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + layers: int = 16, + stacks: int = 1, + base_dilation: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + ): + """Initilialize PosteriorEncoder module. 
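A brief note on the PosteriorEncoder whose definition starts above and whose forward pass follows: the length mask comes from icefall.utils.make_pad_mask (True at padded positions, hence the negation), the projection produces a mean and a log standard deviation stacked along the channel axis, and the latent is drawn with the usual reparameterization z = m + eps * exp(logs) under the mask. A stripped-down sketch of just those steps with made-up shapes (the stats tensor stands in for self.proj(encoder_out) * x_mask):

import torch

B, H, T = 2, 192, 50
lengths = torch.tensor([50, 30])

# equivalent of (~make_pad_mask(lengths)).unsqueeze(1): 1.0 on valid frames, 0.0 on padding
x_mask = (torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1).float()   # (B, 1, T)

stats = torch.randn(B, 2 * H, T)                    # stand-in for the projected encoder output
m, logs = stats.split(stats.size(1) // 2, dim=1)    # mean and log standard deviation
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
print(z.shape)                                      # torch.Size([2, 192, 50])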
+ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size in WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of repeat stacking of WaveNet. + base_dilation (int): Base dilation factor. + global_channels (int): Number of global conditioning channels. + dropout_rate (float): Dropout rate. + bias (bool): Whether to use bias parameters in conv. + use_weight_norm (bool): Whether to apply weight norm. + + """ + super().__init__() + + # define modules + self.input_conv = Conv1d(in_channels, hidden_channels, 1) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + self.proj = Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T_feats). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Encoded hidden representation tensor (B, out_channels, T_feats). + Tensor: Projected mean tensor (B, out_channels, T_feats). + Tensor: Projected scale tensor (B, out_channels, T_feats). + Tensor: Mask tensor for input tensor (B, 1, T_feats). + + """ + x_mask = ( + (~make_pad_mask(x_lengths)) + .unsqueeze(1) + .to( + dtype=x.dtype, + device=x.device, + ) + ) + x = self.input_conv(x) * x_mask + x = self.encoder(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = stats.split(stats.size(1) // 2, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + + return z, m, logs, x_mask diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py new file mode 100644 index 000000000..2d6807cb7 --- /dev/null +++ b/egs/ljspeech/TTS/vits/residual_coupling.py @@ -0,0 +1,229 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Residual affine coupling modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple, Union + +import torch + +from flow import FlipFlow +from wavenet import WaveNet + + +class ResidualAffineCouplingBlock(torch.nn.Module): + """Residual affine coupling block module. + + This is a module of residual affine coupling block, which used as "Flow" in + `Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`_. + + .. 
_`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + + """ + + def __init__( + self, + in_channels: int = 192, + hidden_channels: int = 192, + flows: int = 4, + kernel_size: int = 5, + base_dilation: int = 1, + layers: int = 4, + global_channels: int = -1, + dropout_rate: float = 0.0, + use_weight_norm: bool = True, + bias: bool = True, + use_only_mean: bool = True, + ): + """Initilize ResidualAffineCouplingBlock module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + flows (int): Number of flows. + kernel_size (int): Kernel size for WaveNet. + base_dilation (int): Base dilation factor for WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of stacks of WaveNet. + global_channels (int): Number of global channels. + dropout_rate (float): Dropout rate. + use_weight_norm (bool): Whether to use weight normalization in WaveNet. + bias (bool): Whether to use bias paramters in WaveNet. + use_only_mean (bool): Whether to estimate only mean. + + """ + super().__init__() + + self.flows = torch.nn.ModuleList() + for i in range(flows): + self.flows += [ + ResidualAffineCouplingLayer( + in_channels=in_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + base_dilation=base_dilation, + layers=layers, + stacks=1, + global_channels=global_channels, + dropout_rate=dropout_rate, + use_weight_norm=use_weight_norm, + bias=bias, + use_only_mean=use_only_mean, + ) + ] + self.flows += [FlipFlow()] + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + + """ + if not inverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, inverse=inverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, inverse=inverse) + return x + + +class ResidualAffineCouplingLayer(torch.nn.Module): + """Residual affine coupling layer.""" + + def __init__( + self, + in_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + base_dilation: int = 1, + layers: int = 5, + stacks: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + use_weight_norm: bool = True, + bias: bool = True, + use_only_mean: bool = True, + ): + """Initialzie ResidualAffineCouplingLayer module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size for WaveNet. + base_dilation (int): Base dilation factor for WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of stacks of WaveNet. + global_channels (int): Number of global channels. + dropout_rate (float): Dropout rate. + use_weight_norm (bool): Whether to use weight normalization in WaveNet. + bias (bool): Whether to use bias paramters in WaveNet. + use_only_mean (bool): Whether to estimate only mean. 
+ + """ + assert in_channels % 2 == 0, "in_channels should be divisible by 2" + super().__init__() + self.half_channels = in_channels // 2 + self.use_only_mean = use_only_mean + + # define modules + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + if use_only_mean: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels, + 1, + ) + else: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * 2, + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + xa, xb = x.split(x.size(1) // 2, dim=1) + h = self.input_conv(xa) * x_mask + h = self.encoder(h, x_mask, g=g) + stats = self.proj(h) * x_mask + if not self.use_only_mean: + m, logs = stats.split(stats.size(1) // 2, dim=1) + else: + m = stats + logs = torch.zeros_like(m) + + if not inverse: + xb = m + xb * torch.exp(logs) * x_mask + x = torch.cat([xa, xb], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + xb = (xb - m) * torch.exp(-logs) * x_mask + x = torch.cat([xa, xb], 1) + return x diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py new file mode 100755 index 000000000..8acca7c02 --- /dev/null +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
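Since ResidualAffineCouplingLayer.forward above implements both directions of the flow (xb -> m + xb * exp(logs) with logdet = sum(logs), and the exact inverse), the whole coupling stack can be round-tripped to check invertibility. A hedged sketch, assuming residual_coupling.py and its flow.py / wavenet.py dependencies from this patch are importable:

import torch
from residual_coupling import ResidualAffineCouplingBlock

torch.manual_seed(0)
flow = ResidualAffineCouplingBlock(in_channels=192, hidden_channels=192, flows=4)
flow.eval()

x = torch.randn(2, 192, 37)
x_mask = torch.ones(2, 1, 37)

with torch.no_grad():
    z = flow(x, x_mask)                        # forward direction
    x_rec = flow(z, x_mask, inverse=True)      # inverse direction

print(torch.allclose(x, x_rec, atol=1e-4))     # True: the stack inverts exactly (up to float error)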
+ +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +import onnxruntime as ort +import torch +import torchaudio + +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: noise_scale_dur.numpy(), + self.model.get_inputs()[4].name: alpha.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." + tokens = tokenizer.texts_to_token_ids([text]) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + audio = model(tokens, tokens_lens) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py new file mode 100644 index 000000000..9f337e45b --- /dev/null +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -0,0 +1,662 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text encoder module in VITS. + +This code is based on + - https://github.com/jaywalnut310/vits + - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py + - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py +""" + +import copy +import math +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn + +from icefall.utils import is_jit_tracing, make_pad_mask + + +class TextEncoder(torch.nn.Module): + """Text encoder module in VITS. + + This is a module of text encoder described in `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. + """ + + def __init__( + self, + vocabs: int, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + cnn_module_kernel: int = 5, + num_layers: int = 6, + dropout: float = 0.1, + ): + """Initialize TextEncoder module. + + Args: + vocabs (int): Vocabulary size. + d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + super().__init__() + self.d_model = d_model + + # define modules + self.emb = torch.nn.Embedding(vocabs, d_model) + torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5) + + # We use conformer as text encoder + self.encoder = Transformer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, + num_layers=num_layers, + dropout=dropout, + ) + + self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input index tensor (B, T_text). + x_lengths (Tensor): Length tensor (B,). + + Returns: + Tensor: Encoded hidden representation (B, attention_dim, T_text). + Tensor: Projected mean tensor (B, attention_dim, T_text). + Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Mask tensor for input tensor (B, 1, T_text). 
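+
+        Example (an illustrative sketch, shapes only):
+
+            encoder = TextEncoder(vocabs=500, d_model=192)
+            x = torch.randint(0, 500, (2, 77))
+            x_lengths = torch.full((2,), 77)
+            h, m, logs, mask = encoder(x, x_lengths)
+            # h, m, logs: (2, 192, 77); mask: (2, 1, 77)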
+ + """ + # (B, T_text, embed_dim) + x = self.emb(x) * math.sqrt(self.d_model) + + assert x.size(1) == x_lengths.max().item() + + # (B, T_text) + pad_mask = make_pad_mask(x_lengths) + + # encoder assume the channel last (B, T_text, embed_dim) + x = self.encoder(x, key_padding_mask=pad_mask) + + # convert the channel first (B, embed_dim, T_text) + x = x.transpose(1, 2) + non_pad_mask = (~pad_mask).unsqueeze(1) + stats = self.proj(x) * non_pad_mask + m, logs = stats.split(stats.size(1) // 2, dim=1) + + return x, m, logs, non_pad_mask + + +class Transformer(nn.Module): + """ + Args: + d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + + def __init__( + self, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + cnn_module_kernel: int = 5, + num_layers: int = 6, + dropout: float = 0.1, + ) -> None: + super().__init__() + + self.num_layers = num_layers + self.d_model = d_model + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, + dropout=dropout, + ) + self.encoder = TransformerEncoder(encoder_layer, num_layers) + self.after_norm = nn.LayerNorm(d_model) + + def forward( + self, x: Tensor, key_padding_mask: Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + lengths: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + """ + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + x = self.encoder( + x, pos_emb, key_padding_mask=key_padding_mask + ) # (T, N, C) + + x = self.after_norm(x) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + return x + + +class TransformerEncoderLayer(nn.Module): + """ + TransformerEncoderLayer is made up of self-attn and feedforward. + + Args: + d_model: the number of expected features in the input. + num_heads: the number of heads in the multi-head attention models. + dim_feedforward: the dimension of the feed-forward network model. + dropout: the dropout value (default=0.1). 
+ """ + + def __init__( + self, + d_model: int, + num_heads: int, + dim_feedforward: int, + cnn_module_kernel: int, + dropout: float = 0.1, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + + self.ff_scale = 0.5 + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the transformer encoder layer. + + Args: + src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). + pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). + key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) + """ + # macaron style feed-forward module + src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src))) + + # multi-head self-attention module + src_attn = self.self_attn( + self.norm_mha(src), + pos_emb=pos_emb, + key_padding_mask=key_padding_mask, + ) + src = src + self.dropout(src_attn) + + # convolution module + src = src + self.dropout(self.conv_module(self.norm_conv(src))) + + # feed-forward module + src = src + self.dropout(self.feed_forward(self.norm_ff(src))) + + src = self.norm_final(src) + + return src + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer class. + num_layers: the number of sub-encoder-layers in the encoder. + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). + pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). + key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + key_padding_mask=key_padding_mask, + ) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. 
+ max_len: Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + x_size = x.size(1) + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, seq_len, 2*seq_len-1). 
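+
+        Note:
+            For each query position i and key position j, the output gathers
+            the score that was computed against their relative offset from the
+            2*seq_len-1 relative positions in the input's last dimension.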
+ + Returns: + Tensor: tensor of shape (batch, head, seq_len, seq_len) + """ + (batch_size, num_heads, seq_len, n) = x.shape + + if not is_jit_tracing(): + assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" + + if is_jit_tracing(): + rows = torch.arange(start=seq_len - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, seq_len, seq_len) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, seq_len, seq_len), + (batch_stride, head_stride, time_stride - n_stride, n_stride), + storage_offset=n_stride * (seq_len - 1), + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: Input tensor of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim) + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + Its shape is (batch_size, seq_len). + + Outputs: + A tensor of shape (seq_len, batch_size, embed_dim). + """ + seq_len, batch_size, _ = x.shape + scaling = float(self.head_dim) ** -0.5 + + q, k, v = self.in_proj(x).chunk(3, dim=-1) + + q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) + k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) + v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + + q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) + + p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim) + # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) + p = p.permute(0, 2, 3, 1) + + # (batch_size, num_head, seq_len, head_dim) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len) + + # compute matrix b and matrix d + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1) + matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len) + + # (batch_size, num_head, seq_len, seq_len) + attn_output_weights = (matrix_ac + matrix_bd) * scaling + attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, seq_len) + attn_output_weights = attn_output_weights.view( + batch_size, self.num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + batch_size * self.num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + 
attn_output_weights = nn.functional.dropout( + attn_output_weights, p=self.dropout, training=self.training + ) + + # (batch_size * num_head, seq_len, head_dim) + attn_output = torch.bmm(attn_output_weights, v) + assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim) + + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim) + ) + # (seq_len, batch_size, embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def _test_text_encoder(): + vocabs = 500 + d_model = 192 + batch_size = 5 + seq_len = 100 + + m = TextEncoder(vocabs=vocabs, d_model=d_model) + x, m, logs, mask = m( + x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)), + x_lengths=torch.full((batch_size,), seq_len), + ) + print(x.shape, m.shape, logs.shape, mask.shape) + + +if __name__ == "__main__": + _test_text_encoder() diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py new file mode 100644 index 000000000..0678b26fe --- /dev/null +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -0,0 +1,106 @@ +# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List + +import g2p_en +import tacotron_cleaner.cleaners +from utils import intersperse + + +class Tokenizer(object): + def __init__(self, tokens: str): + """ + Args: + tokens: the file that maps tokens to ids + """ + # Parse token file + self.token2id: Dict[str, int] = {} + with open(tokens, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + id = int(info[0]) + else: + token, id = info[0], int(info[1]) + self.token2id[token] = id + + self.blank_id = self.token2id[""] + self.oov_id = self.token2id[""] + self.vocab_size = len(self.token2id) + + self.g2p = g2p_en.G2p() + + def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True): + """ + Args: + texts: + A list of transcripts. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for text in texts: + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + tokens = self.g2p(text) + token_ids = [] + for t in tokens: + if t in self.token2id: + token_ids.append(self.token2id[t]) + else: + token_ids.append(self.oov_id) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.blank_id) + + token_ids_list.append(token_ids) + + return token_ids_list + + def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True): + """ + Args: + tokens_list: + A list of token list, each corresponding to one utterance. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for tokens in tokens_list: + token_ids = [] + for t in tokens: + if t in self.token2id: + token_ids.append(self.token2id[t]) + else: + token_ids.append(self.oov_id) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.blank_id) + token_ids_list.append(token_ids) + + return token_ids_list diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py new file mode 100755 index 000000000..eb43a4cc9 --- /dev/null +++ b/egs/ljspeech/TTS/vits/train.py @@ -0,0 +1,893 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import numpy as np +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch.optim import Optimizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +from tokenizer import Tokenizer +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. 
+ + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 22050, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
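+
+    # load_checkpoint restores the model weights in place and returns the rest
+    # of the checkpoint dict: the bookkeeping values copied below, plus the
+    # optimizer, scheduler and grad-scaler states that run() consumes later.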
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
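+
+    Note:
+      As implemented below, each batch runs two optimization steps: the
+      discriminator is updated first (the model is called with
+      forward_generator=False), then the generator (forward_generator=True).
+      When --use-fp16 is enabled, one shared GradScaler drives both updates.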
+ """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_audio( + "train/speech_", speech_, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_image( + "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC' + ) + tb_writer.add_image( + "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' + ) + + if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_audio( + "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = 
MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, :tokens_lens[0].item()] + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) + audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # 
the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = ljspeech.train_dataloaders(train_cuts) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/transform.py b/egs/ljspeech/TTS/vits/transform.py new file mode 100644 index 000000000..c20d13130 --- /dev/null +++ b/egs/ljspeech/TTS/vits/transform.py @@ -0,0 +1,218 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py + +"""Flow-related transformation. + +This code is derived from https://github.com/bayesiains/nflows. 
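+
+Example (an illustrative sketch, not part of the recipe; with tails="linear"
+the derivative parameters are expected to carry num_bins - 1 entries, which
+are padded internally):
+
+    import torch
+
+    num_bins = 10
+    x = torch.rand(4, 100) * 2.0 - 1.0      # values inside [-1, 1]
+    w = torch.randn(4, 100, num_bins)       # unnormalized bin widths
+    h = torch.randn(4, 100, num_bins)       # unnormalized bin heights
+    d = torch.randn(4, 100, num_bins - 1)   # unnormalized derivatives
+    y, logabsdet = piecewise_rational_quadratic_transform(
+        x, w, h, d, inverse=False, tails="linear", tail_bound=1.0
+    )
+    x_hat, _ = piecewise_rational_quadratic_transform(
+        y, w, h, d, inverse=True, tails="linear", tail_bound=1.0
+    )
+    # x_hat recovers x up to numerical error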
+ +""" + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +# TODO(kan-bayashi): Documentation and type hint +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +# TODO(kan-bayashi): Documentation and type hint +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +# TODO(kan-bayashi): Documentation and type hint +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * 
widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = _searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = _searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet + + +def _searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py new file mode 100644 index 000000000..0fcbb92c1 --- /dev/null +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -0,0 +1,325 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei 
Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + SpeechSynthesisDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
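+        # The seed is sampled from torch's global RNG in the main process;
+        # _SeedWorkers then adds the worker id, so every dataloader worker
+        # starts from a distinct (but reproducible) random state.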
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py new file mode 100644 index 000000000..2a3dae900 --- /dev/null +++ b/egs/ljspeech/TTS/vits/utils.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple, Union
+import collections
+import logging
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from lhotse.dataset.sampling.base import CutSampler
+from pathlib import Path
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_random_segments(
+    x: torch.Tensor,
+    x_lengths: torch.Tensor,
+    segment_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Get random segments.
+
+    Args:
+        x (Tensor): Input tensor (B, C, T).
+        x_lengths (Tensor): Length tensor (B,).
+        segment_size (int): Segment size.
+
+    Returns:
+        Tensor: Segmented tensor (B, C, segment_size).
+        Tensor: Start index tensor (B,).
+
+    """
+    b, c, t = x.size()
+    max_start_idx = x_lengths - segment_size
+    max_start_idx[max_start_idx < 0] = 0
+    start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
+        dtype=torch.long,
+    )
+    segments = get_segments(x, start_idxs, segment_size)
+
+    return segments, start_idxs
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_segments(
+    x: torch.Tensor,
+    start_idxs: torch.Tensor,
+    segment_size: int,
+) -> torch.Tensor:
+    """Get segments.
+
+    Args:
+        x (Tensor): Input tensor (B, C, T).
+        start_idxs (Tensor): Start index tensor (B,).
+        segment_size (int): Segment size.
+
+    Returns:
+        Tensor: Segmented tensor (B, C, segment_size).
+
+    """
+    b, c, t = x.size()
+    segments = x.new_zeros(b, c, segment_size)
+    for i, start_idx in enumerate(start_idxs):
+        segments[i] = x[i, :, start_idx : start_idx + segment_size]
+    return segments
+
+
+# from https://github.com/jaywalnut310/vits/blob/main/commons.py
+def intersperse(sequence, item=0):
+    result = [item] * (len(sequence) * 2 + 1)
+    result[1::2] = sequence
+    return result
+
+
+# from https://github.com/jaywalnut310/vits/blob/main/utils.py
+MATPLOTLIB_FLAG = False
+
+
+def plot_feature(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                   interpolation='none')
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+class MetricsTracker(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        # This class will play a role as metrics tracker.
+        # It can record many metrics, including but not limited to loss.
+        super(MetricsTracker, self).__init__(int)
+
+    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ""
+        for k, v in self.norm_items():
+            norm_value = "%.4g" % v
+            ans += str(k) + "=" + str(norm_value) + ", "
+        samples = "%.2f" % self["samples"]
+        ans += "over " + str(samples) + " samples."
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('loss_1', 0.1), ('loss_2', 0.07)]
+        """
+        samples = self["samples"] if "samples" in self else 1
+        ans = []
+        for k, v in self.items():
+            if k == "samples":
+                continue
+            norm_value = float(v) / samples
+            ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([float(self[k]) for k in keys], device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(
+        self,
+        tb_writer: SummaryWriter,
+        prefix: str,
+        batch_idx: int,
+    ) -> None:
+        """Add logging information to a TensorBoard writer.
+
+        Args:
+            tb_writer: a TensorBoard writer
+            prefix: a prefix for the name of the loss, e.g. "train/valid_",
+                or "train/current_"
+            batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+# checkpoint saving and loading
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def save_checkpoint(
+    filename: Path,
+    model: Union[nn.Module, DDP],
+    params: Optional[Dict[str, Any]] = None,
+    optimizer_g: Optional[Optimizer] = None,
+    optimizer_d: Optional[Optimizer] = None,
+    scheduler_g: Optional[LRSchedulerType] = None,
+    scheduler_d: Optional[LRSchedulerType] = None,
+    scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
+    rank: int = 0,
+) -> None:
+    """Save training information to a file.
+
+    Args:
+      filename:
+        The checkpoint filename.
+      model:
+        The model to be saved. We only save its `state_dict()`.
+      params:
+        User defined parameters, e.g., epoch, loss.
+      optimizer_g:
+        The optimizer for generator used in the training.
+        Its `state_dict` will be saved.
+      optimizer_d:
+        The optimizer for discriminator used in the training.
+        Its `state_dict` will be saved.
+      scheduler_g:
+        The learning rate scheduler for generator used in the training.
+        Its `state_dict` will be saved.
+      scheduler_d:
+        The learning rate scheduler for discriminator used in the training.
+        Its `state_dict` will be saved.
+      scaler:
+        The GradScaler to be saved. We only save its `state_dict()`.
+      sampler:
+        The sampler used in the training dataloader. Its `state_dict` will
+        be saved so that training can be resumed from the same position.
+      rank:
+        Used in DDP. We save checkpoint only for the node whose rank is 0.
+    Returns:
+      Return None.
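+
+    Example (an illustrative sketch; ``generator`` and ``optim_g`` are
+    assumed to be created by the caller):
+
+        save_checkpoint(
+            filename=Path("vits/exp/epoch-10.pt"),
+            model=generator,
+            params={"epoch": 10},
+            optimizer_g=optim_g,
+        )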
+ """ + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, + "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, + "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, + "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + torch.save(checkpoint, filename) diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py new file mode 100644 index 000000000..d5e20a578 --- /dev/null +++ b/egs/ljspeech/TTS/vits/vits.py @@ -0,0 +1,610 @@ +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""VITS module for GAN-TTS task.""" + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from torch.cuda.amp import autocast + +from hifigan import ( + HiFiGANMultiPeriodDiscriminator, + HiFiGANMultiScaleDiscriminator, + HiFiGANMultiScaleMultiPeriodDiscriminator, + HiFiGANPeriodDiscriminator, + HiFiGANScaleDiscriminator, +) +from loss import ( + DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + KLDivergenceLoss, + MelSpectrogramLoss, +) +from utils import get_segments +from generator import VITSGenerator + + +AVAILABLE_GENERATERS = { + "vits_generator": VITSGenerator, +} +AVAILABLE_DISCRIMINATORS = { + "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, + "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, + "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, + "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, + "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA +} + + +class VITS(nn.Module): + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` + """ + + def __init__( + self, + # generator related + vocab_size: int, + feature_dim: int = 513, + sampling_rate: int = 22050, + generator_type: str = "vits_generator", + generator_params: Dict[str, Any] = { + "hidden_channels": 192, + "spks": None, + "langs": None, + "spk_embed_dim": None, + "global_channels": -1, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + 
"use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + }, + # discriminator related + discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", + discriminator_params: Dict[str, Any] = { + "scales": 1, + "scale_downsample_pooling": "AvgPool1d", + "scale_downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "scale_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + "follow_official_norm": False, + "periods": [2, 3, 5, 7, 11], + "period_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + # loss related + generator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + discriminator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + feat_match_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "average_by_layers": False, + "include_final_outputs": True, + }, + mel_loss_params: Dict[str, Any] = { + "frame_shift": 256, + "frame_length": 1024, + "n_mels": 80, + }, + lambda_adv: float = 1.0, + lambda_mel: float = 45.0, + lambda_feat_match: float = 2.0, + lambda_dur: float = 1.0, + lambda_kl: float = 1.0, + cache_generator_outputs: bool = True, + ): + """Initialize VITS module. + + Args: + idim (int): Input vocabrary size. + odim (int): Acoustic feature dimension. The actual output channels will + be 1 since VITS is the end-to-end text-to-wave model but for the + compatibility odim is used to indicate the acoustic feature dimension. + sampling_rate (int): Sampling rate, not used for the training but it will + be referred in saving waveform during the inference. + generator_type (str): Generator type. + generator_params (Dict[str, Any]): Parameter dict for generator. + discriminator_type (str): Discriminator type. + discriminator_params (Dict[str, Any]): Parameter dict for discriminator. + generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator + adversarial loss. + discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for + discriminator adversarial loss. + feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss. + mel_loss_params (Dict[str, Any]): Parameter dict for mel loss. + lambda_adv (float): Loss scaling coefficient for adversarial loss. + lambda_mel (float): Loss scaling coefficient for mel spectrogram loss. + lambda_feat_match (float): Loss scaling coefficient for feat match loss. + lambda_dur (float): Loss scaling coefficient for duration loss. + lambda_kl (float): Loss scaling coefficient for KL divergence loss. + cache_generator_outputs (bool): Whether to cache generator outputs. 
+ + """ + super().__init__() + + # define modules + generator_class = AVAILABLE_GENERATERS[generator_type] + if generator_type == "vits_generator": + # NOTE(kan-bayashi): Update parameters for the compatibility. + # The idim and odim is automatically decided from input data, + # where idim represents #vocabularies and odim represents + # the input acoustic feature dimension. + generator_params.update(vocabs=vocab_size, aux_channels=feature_dim) + self.generator = generator_class( + **generator_params, + ) + discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] + self.discriminator = discriminator_class( + **discriminator_params, + ) + self.generator_adv_loss = GeneratorAdversarialLoss( + **generator_adv_loss_params, + ) + self.discriminator_adv_loss = DiscriminatorAdversarialLoss( + **discriminator_adv_loss_params, + ) + self.feat_match_loss = FeatureMatchLoss( + **feat_match_loss_params, + ) + mel_loss_params.update(sampling_rate=sampling_rate) + self.mel_loss = MelSpectrogramLoss( + **mel_loss_params, + ) + self.kl_loss = KLDivergenceLoss() + + # coefficients + self.lambda_adv = lambda_adv + self.lambda_mel = lambda_mel + self.lambda_kl = lambda_kl + self.lambda_feat_match = lambda_feat_match + self.lambda_dur = lambda_dur + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + # store sampling rate for saving wav file + # (not used for the training) + self.sampling_rate = sampling_rate + + # store parameters for test compatibility + self.spks = self.generator.spks + self.langs = self.generator.langs + self.spk_embed_dim = self.generator.spk_embed_dim + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + return_sample: bool = False, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + forward_generator: bool = True, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + forward_generator (bool): Whether to forward generator. + + Returns: + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. 
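+
+        A typical GAN training step calls this twice per batch, once for each
+        network (a sketch; the input tensors are assumed to come from the
+        dataloader):
+
+            loss_g, stats_g = model(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                speech=speech,
+                speech_lengths=speech_lengths,
+                forward_generator=True,
+            )
+            loss_d, stats_d = model(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                speech=speech,
+                speech_lengths=speech_lengths,
+                forward_generator=False,
+            )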
+ """ + if forward_generator: + return self._forward_generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + return_sample=return_sample, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + return self._forward_discrminator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + + def _forward_generator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + return_sample: bool = False, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs + _, z_p, m_p, logs_p, _, logs_q = outs_ + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + if not return_sample: + mel_loss = self.mel_loss(speech_hat_, speech_) + else: + mel_loss, (mel_hat_, mel_) = self.mel_loss( + speech_hat_, speech_, return_mel=True + ) + kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) + dur_loss = torch.sum(dur_nll.float()) + adv_loss = self.generator_adv_loss(p_hat) + feat_match_loss = self.feat_match_loss(p_hat, p) + + mel_loss = mel_loss * self.lambda_mel + kl_loss = kl_loss * self.lambda_kl + dur_loss = dur_loss * self.lambda_dur + adv_loss = adv_loss * self.lambda_adv + feat_match_loss = feat_match_loss * self.lambda_feat_match + loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss + + stats = dict( + generator_loss=loss.item(), + generator_mel_loss=mel_loss.item(), + generator_kl_loss=kl_loss.item(), + generator_dur_loss=dur_loss.item(), + generator_adv_loss=adv_loss.item(), + generator_feat_match_loss=feat_match_loss.item(), + 
) + + if return_sample: + stats["returned_sample"] = ( + speech_hat_[0].data.cpu().numpy(), + speech_[0].data.cpu().numpy(), + mel_hat_[0].data.cpu().numpy(), + mel_[0].data.cpu().numpy(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def _forward_discrminator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform discriminator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, _, _, start_idxs, *_ = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_.detach()) + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) + loss = real_loss + fake_loss + + stats = dict( + discriminator_loss=loss.item(), + discriminator_real_loss=real_loss.item(), + discriminator_fake_loss=fake_loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def inference( + self, + text: torch.Tensor, + feats: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for single sample. + + Args: + text (Tensor): Input text index tensor (T_text,). + feats (Tensor): Feature tensor (T_feats, aux_channels). + sids (Tensor): Speaker index tensor (1,). + spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,). + lids (Tensor): Language index tensor (1,). + durations (Tensor): Ground-truth duration tensor (T_text,). + noise_scale (float): Noise scale value for flow. 
+ noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + * wav (Tensor): Generated waveform tensor (T_wav,). + * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). + * duration (Tensor): Predicted duration tensor (T_text,). + """ + # setup + text = text[None] + text_lengths = torch.tensor( + [text.size(1)], + dtype=torch.long, + device=text.device, + ) + if sids is not None: + sids = sids.view(1) + if lids is not None: + lids = lids.view(1) + if durations is not None: + durations = durations.view(1, 1, -1) + + # inference + if use_teacher_forcing: + assert feats is not None + feats = feats[None].transpose(1, 2) + feats_lengths = torch.tensor( + [feats.size(2)], + dtype=torch.long, + device=feats.device, + ) + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + max_len=max_len, + use_teacher_forcing=use_teacher_forcing, + ) + else: + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + sids=sids, + spembs=spembs, + lids=lids, + dur=durations, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return wav.view(-1), att_w[0], dur[0] + + def inference_batch( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for one batch. + + Args: + text (Tensor): Input text index tensor (B, T_text). + text_lengths (Tensor): Input text index tensor (B,). + sids (Tensor): Speaker index tensor (B,). + noise_scale (float): Noise scale value for flow. + noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + + Returns: + * wav (Tensor): Generated waveform tensor (B, T_wav). + * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). + * duration (Tensor): Predicted duration tensor (B, T_text). + """ + # inference + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + sids=sids, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return wav, att_w, dur diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py new file mode 100644 index 000000000..fbe1be52b --- /dev/null +++ b/egs/ljspeech/TTS/vits/wavenet.py @@ -0,0 +1,349 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""WaveNet modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. 
+ +""" + +import math +import logging + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +class WaveNet(torch.nn.Module): + """WaveNet with global conditioning.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + kernel_size: int = 3, + layers: int = 30, + stacks: int = 3, + base_dilation: int = 2, + residual_channels: int = 64, + aux_channels: int = -1, + gate_channels: int = 128, + skip_channels: int = 64, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + use_first_conv: bool = False, + use_last_conv: bool = False, + scale_residual: bool = False, + scale_skip_connect: bool = False, + ): + """Initialize WaveNet module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of dilated convolution. + layers (int): Number of residual block layers. + stacks (int): Number of stacks i.e., dilation cycles. + base_dilation (int): Base dilation factor. + residual_channels (int): Number of channels in residual conv. + gate_channels (int): Number of channels in gated conv. + skip_channels (int): Number of channels in skip conv. + aux_channels (int): Number of channels for local conditioning feature. + global_channels (int): Number of channels for global conditioning feature. + dropout_rate (float): Dropout rate. 0.0 means no dropout applied. + bias (bool): Whether to use bias parameter in conv layer. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + use_first_conv (bool): Whether to use the first conv layers. + use_last_conv (bool): Whether to use the last conv layers. + scale_residual (bool): Whether to scale the residual outputs. + scale_skip_connect (bool): Whether to scale the skip connection outputs. + + """ + super().__init__() + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + self.base_dilation = base_dilation + self.use_first_conv = use_first_conv + self.use_last_conv = use_last_conv + self.scale_skip_connect = scale_skip_connect + + # check the number of layers and stacks + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + # define first convolution + if self.use_first_conv: + self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(layers): + dilation = base_dilation ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + global_channels=global_channels, + dilation=dilation, + dropout_rate=dropout_rate, + bias=bias, + scale_residual=scale_residual, + ) + self.conv_layers += [conv] + + # define output layers + if self.use_last_conv: + self.last_conv = torch.nn.Sequential( + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, skip_channels, bias=True), + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, out_channels, bias=True), + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward( + self, + x: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Calculate forward propagation. 
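+
+        Skip connections from all residual blocks are summed (and optionally
+        rescaled by ``sqrt(1 / num_layers)`` when ``scale_skip_connect`` is
+        set) to form the output of this stack.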
+ + Args: + x (Tensor): Input noise signal (B, 1, T) if use_first_conv else + (B, residual_channels, T). + x_mask (Optional[Tensor]): Mask tensor (B, 1, T). + c (Optional[Tensor]): Local conditioning features (B, aux_channels, T). + g (Optional[Tensor]): Global conditioning features (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T) if use_last_conv else + (B, residual_channels, T). + + """ + # encode to hidden representation + if self.use_first_conv: + x = self.first_conv(x) + + # residual block + skips = 0.0 + for f in self.conv_layers: + x, h = f(x, x_mask=x_mask, c=c, g=g) + skips = skips + h + x = skips + if self.scale_skip_connect: + x = x * math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + if self.use_last_conv: + x = self.last_conv(x) + + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m: torch.nn.Module): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size( + layers: int, + stacks: int, + kernel_size: int, + base_dilation: int, + ) -> int: + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self) -> int: + """Return receptive field size.""" + return self._get_receptive_field_size( + self.layers, self.stacks, self.kernel_size, self.base_dilation + ) + + +class Conv1d(torch.nn.Conv1d): + """Conv1d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv1d module.""" + super().__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels: int, out_channels: int, bias: bool): + """Initialize 1x1 Conv1d module.""" + super().__init__( + in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias + ) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__( + self, + kernel_size: int = 3, + residual_channels: int = 64, + gate_channels: int = 128, + skip_channels: int = 64, + aux_channels: int = 80, + global_channels: int = -1, + dropout_rate: float = 0.0, + dilation: int = 1, + bias: bool = True, + scale_residual: bool = False, + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + residual_channels (int): Number of channels for residual connection. + skip_channels (int): Number of channels for skip connection. + aux_channels (int): Number of local conditioning channels. + dropout (float): Dropout probability. + dilation (int): Dilation factor. 
+ bias (bool): Whether to add bias parameter in convolution layers. + scale_residual (bool): Whether to scale the residual outputs. + + """ + super().__init__() + self.dropout_rate = dropout_rate + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.scale_residual = scale_residual + + # check + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + assert gate_channels % 2 == 0 + + # dilation conv + padding = (kernel_size - 1) // 2 * dilation + self.conv = Conv1d( + residual_channels, + gate_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) + else: + self.conv1x1_aux = None + + # global conditioning + if global_channels > 0: + self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False) + else: + self.conv1x1_glo = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + + # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency + # (integrate res 1x1 + skip 1x1 convs) + self.conv1x1_out = Conv1d1x1( + gate_out_channels, residual_channels + skip_channels, bias=bias + ) + + def forward( + self, + x: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, residual_channels, T). + x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T). + c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor for residual connection (B, residual_channels, T). + Tensor: Output tensor for skip connection (B, skip_channels, T). 
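+
+        Example (a shape sketch with the default channel sizes):
+
+            >>> block = ResidualBlock()
+            >>> x = torch.randn(2, 64, 100)
+            >>> res, skip = block(x)  # res: (2, 64, 100), skip: (2, 64, 100)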
+ + """ + residual = x + x = F.dropout(x, p=self.dropout_rate, training=self.training) + x = self.conv(x) + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + # global conditioning + if g is not None: + g = self.conv1x1_glo(g) + ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ga, xb + gb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # residual + skip 1x1 conv + x = self.conv1x1_out(x) + if x_mask is not None: + x = x * x_mask + + # split integrated conv results + x, s = x.split([self.residual_channels, self.skip_channels], dim=1) + + # for residual connection + x = x + residual + if self.scale_residual: + x = x * math.sqrt(0.5) + + return x, s diff --git a/pyproject.toml b/pyproject.toml index c40143fb9..435256416 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,4 +14,5 @@ exclude = ''' | icefall\/diagnostics\.py | icefall\/profiler\.py | egs\/librispeech\/ASR\/zipformer + | egs\/ljspeech\/TTS\/vits ''' From f08af2fa2217e226394b6f03442952d104bd984e Mon Sep 17 00:00:00 2001 From: LoganLiu66 <2319277867@qq.com> Date: Mon, 4 Dec 2023 22:29:42 +0800 Subject: [PATCH 053/216] fix initial states (#1398) Co-authored-by: liujiawang02 --- .../pruned_transducer_stateless7_streaming/decode_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py index 0d7e86fcf..2c4b144fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -82,12 +82,12 @@ class DecodeStream(object): self.pad_length = 7 if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] elif params.decoding_method == "modified_beam_search": self.hyps = HypothesisList() self.hyps.add( Hypothesis( - ys=[params.blank_id] * params.context_size, + ys=[-1] * (params.context_size - 1) + [params.blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) From 735fb9a73dea7d27e95056add6598ae7a282d6f9 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 6 Dec 2023 09:59:19 +0800 Subject: [PATCH 054/216] A TTS recipe VITS on VCTK dataset (#1380) * init * isort formatted * minor updates * Create shared * Update prepare_tokens_vctk.py * Update prepare_tokens_vctk.py * Update prepare_tokens_vctk.py * Update prepare.sh * updated * Update train.py * Update train.py * Update tts_datamodule.py * Update train.py * Update train.py * Update train.py * Update train.py * Update train.py * Update train.py * fixed formatting issue * Update infer.py * removed redundant files * Create monotonic_align * removed redundant files * created symlinks * Update prepare.sh * minor adjustments * Create requirements_tts.txt * Update requirements_tts.txt added version constraints * Update infer.py * Update infer.py * Update infer.py * updated docs * Update export-onnx.py * Update export-onnx.py * Update test_onnx.py * updated requirements.txt * Update test_onnx.py * Update test_onnx.py * docs updated * docs fixed * minor updates --- docs/source/recipes/TTS/index.rst | 1 + docs/source/recipes/TTS/ljspeech/vits.rst | 12 +- 
docs/source/recipes/TTS/vctk/vits.rst | 125 +++ egs/ljspeech/TTS/prepare.sh | 16 +- egs/ljspeech/TTS/vits/duration_predictor.py | 1 - egs/ljspeech/TTS/vits/export-onnx.py | 8 +- egs/ljspeech/TTS/vits/flow.py | 1 - egs/ljspeech/TTS/vits/generator.py | 5 +- egs/ljspeech/TTS/vits/infer.py | 29 +- egs/ljspeech/TTS/vits/loss.py | 1 - egs/ljspeech/TTS/vits/posterior_encoder.py | 2 +- egs/ljspeech/TTS/vits/residual_coupling.py | 1 - egs/ljspeech/TTS/vits/test_onnx.py | 2 +- egs/ljspeech/TTS/vits/text_encoder.py | 48 +- egs/ljspeech/TTS/vits/tokenizer.py | 4 +- egs/ljspeech/TTS/vits/train.py | 93 +- egs/ljspeech/TTS/vits/tts_datamodule.py | 2 +- egs/ljspeech/TTS/vits/utils.py | 14 +- egs/ljspeech/TTS/vits/vits.py | 9 +- egs/ljspeech/TTS/vits/wavenet.py | 3 +- .../TTS/local/compute_spectrogram_vctk.py | 107 ++ .../TTS/local/display_manifest_statistics.py | 83 ++ egs/vctk/TTS/local/prepare_token_file.py | 104 ++ egs/vctk/TTS/local/prepare_tokens_vctk.py | 61 + egs/vctk/TTS/local/validate_manifest.py | 70 ++ egs/vctk/TTS/prepare.sh | 131 +++ egs/vctk/TTS/shared | 1 + egs/vctk/TTS/vits/duration_predictor.py | 1 + egs/vctk/TTS/vits/export-onnx.py | 284 +++++ egs/vctk/TTS/vits/flow.py | 1 + egs/vctk/TTS/vits/generator.py | 1 + egs/vctk/TTS/vits/hifigan.py | 1 + egs/vctk/TTS/vits/infer.py | 272 +++++ egs/vctk/TTS/vits/loss.py | 1 + egs/vctk/TTS/vits/monotonic_align | 1 + egs/vctk/TTS/vits/posterior_encoder.py | 1 + egs/vctk/TTS/vits/residual_coupling.py | 1 + egs/vctk/TTS/vits/test_onnx.py | 138 +++ egs/vctk/TTS/vits/text_encoder.py | 1 + egs/vctk/TTS/vits/tokenizer.py | 1 + egs/vctk/TTS/vits/train.py | 1000 +++++++++++++++++ egs/vctk/TTS/vits/transform.py | 1 + egs/vctk/TTS/vits/tts_datamodule.py | 338 ++++++ egs/vctk/TTS/vits/utils.py | 1 + egs/vctk/TTS/vits/vits.py | 1 + egs/vctk/TTS/vits/wavenet.py | 1 + requirements-tts.txt | 6 + requirements.txt | 2 + 48 files changed, 2904 insertions(+), 84 deletions(-) create mode 100644 docs/source/recipes/TTS/vctk/vits.rst create mode 100755 egs/vctk/TTS/local/compute_spectrogram_vctk.py create mode 100755 egs/vctk/TTS/local/display_manifest_statistics.py create mode 100755 egs/vctk/TTS/local/prepare_token_file.py create mode 100755 egs/vctk/TTS/local/prepare_tokens_vctk.py create mode 100755 egs/vctk/TTS/local/validate_manifest.py create mode 100755 egs/vctk/TTS/prepare.sh create mode 120000 egs/vctk/TTS/shared create mode 120000 egs/vctk/TTS/vits/duration_predictor.py create mode 100755 egs/vctk/TTS/vits/export-onnx.py create mode 120000 egs/vctk/TTS/vits/flow.py create mode 120000 egs/vctk/TTS/vits/generator.py create mode 120000 egs/vctk/TTS/vits/hifigan.py create mode 100755 egs/vctk/TTS/vits/infer.py create mode 120000 egs/vctk/TTS/vits/loss.py create mode 120000 egs/vctk/TTS/vits/monotonic_align create mode 120000 egs/vctk/TTS/vits/posterior_encoder.py create mode 120000 egs/vctk/TTS/vits/residual_coupling.py create mode 100755 egs/vctk/TTS/vits/test_onnx.py create mode 120000 egs/vctk/TTS/vits/text_encoder.py create mode 120000 egs/vctk/TTS/vits/tokenizer.py create mode 100755 egs/vctk/TTS/vits/train.py create mode 120000 egs/vctk/TTS/vits/transform.py create mode 100644 egs/vctk/TTS/vits/tts_datamodule.py create mode 120000 egs/vctk/TTS/vits/utils.py create mode 120000 egs/vctk/TTS/vits/vits.py create mode 120000 egs/vctk/TTS/vits/wavenet.py create mode 100644 requirements-tts.txt diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst index aa891c072..80d67a2f3 100644 --- a/docs/source/recipes/TTS/index.rst +++ 
b/docs/source/recipes/TTS/index.rst @@ -5,3 +5,4 @@ TTS :maxdepth: 2 ljspeech/vits + vctk/vits \ No newline at end of file diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 385fd3c70..d08aa0f47 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -4,6 +4,10 @@ VITS This tutorial shows you how to train an VITS model with the `LJSpeech `_ dataset. +.. note:: + + TTS related recipes require packages in ``requirements-tts.txt``. + .. note:: The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ @@ -27,6 +31,12 @@ To run stage 1 to stage 5, use Build Monotonic Alignment Search -------------------------------- +.. code-block:: bash + + $ ./prepare.sh --stage -1 --stop_stage -1 + +or + .. code-block:: bash $ cd vits/monotonic_align @@ -74,7 +84,7 @@ training part first. It will save the ground-truth and generated wavs to the dir $ ./vits/infer.py \ --epoch 1000 \ --exp-dir vits/exp \ - --tokens data/tokens.txt + --tokens data/tokens.txt \ --max-duration 500 .. note:: diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst new file mode 100644 index 000000000..34024a5ea --- /dev/null +++ b/docs/source/recipes/TTS/vctk/vits.rst @@ -0,0 +1,125 @@ +VITS +=============== + +This tutorial shows you how to train an VITS model +with the `VCTK `_ dataset. + +.. note:: + + TTS related recipes require packages in ``requirements-tts.txt``. + +.. note:: + + The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/vctk/TTS + $ ./prepare.sh + +To run stage 1 to stage 6, use + +.. code-block:: bash + + $ ./prepare.sh --stage 1 --stop_stage 6 + + +Build Monotonic Alignment Search +-------------------------------- + +To build the monotonic alignment search, use the following commands: + +.. code-block:: bash + + $ ./prepare.sh --stage -1 --stop_stage -1 + +or + +.. code-block:: bash + + $ cd vits/monotonic_align + $ python setup.py build_ext --inplace + $ cd ../../ + + +Training +-------- + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0,1,2,3" + $ ./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 350 + +.. note:: + + You can adjust the hyper-parameters to control the size of the VITS model and + the training configurations. For more details, please run ``./vits/train.py --help``. + +.. note:: + + The training can take a long time (usually a couple of days). + +Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``. + + +Inference +--------- + +The inference part uses checkpoints saved by the training part, so you have to run the +training part first. It will save the ground-truth and generated wavs to the directory +``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``. + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0" + $ ./vits/infer.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt \ + --max-duration 500 + +.. note:: + + For more details, please run ``./vits/infer.py --help``. + + +Export models +------------- + +Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: +``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. + +.. 
code-block:: bash + + $ ./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +You can test the exported ONNX model with: + +.. code-block:: bash + + $ ./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following link: + + - ``_ diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 8ee40896e..ed0a07f5e 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -5,8 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -nj=1 -stage=-1 +stage=0 stop_stage=100 dl_dir=$PWD/download @@ -25,6 +24,17 @@ log() { log "dl_dir: $dl_dir" +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" @@ -113,5 +123,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --tokens data/tokens.txt fi fi - - diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py index c29a28479..1a8190014 100644 --- a/egs/ljspeech/TTS/vits/duration_predictor.py +++ b/egs/ljspeech/TTS/vits/duration_predictor.py @@ -14,7 +14,6 @@ from typing import Optional import torch import torch.nn.functional as F - from flow import ( ConvFlow, DilatedDepthSeparableConv, diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 154de4bf4..2068adeea 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -180,7 +180,13 @@ def export_model_onnx( model_filename, verbose=False, opset_version=opset_version, - input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"], + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "noise_scale_dur", + "alpha", + ], output_names=["audio"], dynamic_axes={ "tokens": {0: "N", 1: "T"}, diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py index 206bd5e3e..2b84f6434 100644 --- a/egs/ljspeech/TTS/vits/flow.py +++ b/egs/ljspeech/TTS/vits/flow.py @@ -13,7 +13,6 @@ import math from typing import Optional, Tuple, Union import torch - from transform import piecewise_rational_quadratic_transform diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index efb0e254c..66c8cedb1 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -16,9 +16,6 @@ from typing import List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F - -from icefall.utils import make_pad_mask - from duration_predictor import StochasticDurationPredictor from hifigan import HiFiGANGenerator from posterior_encoder import PosteriorEncoder @@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock from text_encoder import TextEncoder from utils import get_random_segments +from icefall.utils import make_pad_mask + class VITSGenerator(torch.nn.Module): """Generator module in VITS, `Conditional Variational Autoencoder diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 91a35e360..cf0d20ae2 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ 
b/egs/ljspeech/TTS/vits/infer.py @@ -36,13 +36,12 @@ import k2 import torch import torch.nn as nn import torchaudio - -from train import get_model, get_params from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule from icefall.checkpoint import load_checkpoint from icefall.utils import AttributeDict, setup_logger -from tts_datamodule import LJSpeechTtsDataModule def get_parser(): @@ -107,12 +106,12 @@ def infer_dataset( for i in range(batch_size): torchaudio.save( str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), - audio[i:i + 1, :audio_lens[i]], + audio[i : i + 1, : audio_lens[i]], sample_rate=params.sampling_rate, ) torchaudio.save( str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), - audio_pred[i:i + 1, :audio_lens_pred[i]], + audio_pred[i : i + 1, : audio_lens_pred[i]], sample_rate=params.sampling_rate, ) @@ -144,14 +143,24 @@ def infer_dataset( audio_lens = batch["audio_lens"].tolist() cut_ids = [cut.id for cut in batch["cut"]] - audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred, _, durations = model.inference_batch( + text=tokens, text_lengths=tokens_lens + ) audio_pred = audio_pred.detach().cpu() # convert to samples - audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + audio_lens_pred = ( + (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + ) futures.append( executor.submit( - _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + _save_worker, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, ) ) @@ -160,7 +169,9 @@ def infer_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) # return results for f in futures: f.result() diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py index 21aaad6e7..2f4dc9bc0 100644 --- a/egs/ljspeech/TTS/vits/loss.py +++ b/egs/ljspeech/TTS/vits/loss.py @@ -14,7 +14,6 @@ from typing import List, Tuple, Union import torch import torch.distributions as D import torch.nn.functional as F - from lhotse.features.kaldi import Wav2LogFilterBank diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py index 6b8a5be52..1104fb864 100644 --- a/egs/ljspeech/TTS/vits/posterior_encoder.py +++ b/egs/ljspeech/TTS/vits/posterior_encoder.py @@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits. from typing import Optional, Tuple import torch +from wavenet import Conv1d, WaveNet from icefall.utils import make_pad_mask -from wavenet import WaveNet, Conv1d class PosteriorEncoder(torch.nn.Module): diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py index 2d6807cb7..f9a2a3786 100644 --- a/egs/ljspeech/TTS/vits/residual_coupling.py +++ b/egs/ljspeech/TTS/vits/residual_coupling.py @@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits. 
from typing import Optional, Tuple, Union import torch - from flow import FlipFlow from wavenet import WaveNet diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 8acca7c02..686fee2a0 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -28,10 +28,10 @@ Use the onnx model to generate a wav: import argparse import logging + import onnxruntime as ort import torch import torchaudio - from tokenizer import Tokenizer diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index 9f337e45b..fcbae7103 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -169,9 +169,7 @@ class Transformer(nn.Module): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - x = self.encoder( - x, pos_emb, key_padding_mask=key_padding_mask - ) # (T, N, C) + x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) x = self.after_norm(x) @@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + self.self_attn = RelPositionMultiheadAttention( + d_model, num_heads, dropout=dropout + ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module): key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) """ # macaron style feed-forward module - src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src))) + src = src + self.ff_scale * self.dropout( + self.feed_forward_macaron(self.norm_ff_macaron(src)) + ) # multi-head self-attention module src_attn = self.self_attn( @@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module): q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) - v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + v = ( + v.contiguous() + .view(seq_len, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) - p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim) + p = self.linear_pos(pos_emb).view( + pos_emb.size(0), -1, self.num_heads, self.head_dim + ) # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) p = p.permute(0, 2, 3, 1) @@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch_size, num_head, seq_len, seq_len) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1) - matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch_size, num_head, seq_len, 2*seq_len-1) + matrix_bd = self.rel_shift( + matrix_bd + ) # (batch_size, num_head, seq_len, seq_len) # (batch_size, num_head, seq_len, seq_len) attn_output_weights = (matrix_ac + matrix_bd) * scaling - 
attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len) + attn_output_weights = attn_output_weights.view( + batch_size * self.num_heads, seq_len, seq_len + ) if key_padding_mask is not None: assert key_padding_mask.shape == (batch_size, seq_len) @@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module): # (batch_size * num_head, seq_len, head_dim) attn_output = torch.bmm(attn_output_weights, v) - assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim) + assert attn_output.shape == ( + batch_size * self.num_heads, + seq_len, + self.head_dim, + ) attn_output = ( - attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, batch_size, self.embed_dim) ) # (seq_len, batch_size, embed_dim) attn_output = self.out_proj(attn_output) diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 0678b26fe..70f1240b4 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -78,7 +78,9 @@ class Tokenizer(object): return token_ids_list - def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True): + def tokens_to_token_ids( + self, tokens_list: List[str], intersperse_blank: bool = True + ): """ Args: tokens_list: diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index eb43a4cc9..71c4224fa 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -18,21 +18,25 @@ import argparse import logging -import numpy as np from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 +import numpy as np import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed -from torch.optim import Optimizer +from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS from icefall import diagnostics from icefall.checkpoint import load_checkpoint @@ -41,11 +45,6 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, setup_logger, str2bool -from tokenizer import Tokenizer -from tts_datamodule import LJSpeechTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -385,11 +384,12 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) loss_info = MetricsTracker() - loss_info['samples'] = batch_size + loss_info["samples"] = batch_size try: with autocast(enabled=params.use_fp16): @@ -446,7 +446,9 @@ def train_one_epoch( # behavior depending on the current grad scale. 
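            # The code below doubles the grad scale whenever it drops below 8.0,
            # or below 32.0 once every 400 batches; a scale below 0.01 is treated
            # as a failure case (see the saved_bad_model handling just below).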
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0): + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: @@ -482,9 +484,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train @@ -492,19 +492,34 @@ def train_one_epoch( if "returned_sample" in stats_g: speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] tb_writer.add_audio( - "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_audio( - "train/speech_", speech_, params.batch_idx_train, params.sampling_rate + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_image( - "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC' + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", ) tb_writer.add_image( - "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", ) - if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info, (speech_hat, speech) = compute_validation_loss( params=params, @@ -523,10 +538,16 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) tb_writer.add_audio( - "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_audio( - "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, ) loss_value = tot_loss["generator_loss"] / tot_loss["samples"] @@ -555,11 +576,17 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device) loss_info = MetricsTracker() - loss_info['samples'] = batch_size + loss_info["samples"] = batch_size # forward discriminator loss_d, stats_d = model( @@ -596,12 +623,17 @@ def compute_validation_loss( if batch_idx == 0 and rank == 0: inner_model = model.module if isinstance(model, DDP) else model audio_pred, _, duration = inner_model.inference( - text=tokens[0, :tokens_lens[0].item()] + text=tokens[0, : tokens_lens[0].item()] ) audio_pred = audio_pred.data.cpu().numpy() - audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() - assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) - audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() + audio_len_pred = ( + 
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() returned_sample = (audio_pred, audio_gt) if world_size > 1: @@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom( batches, crit_values = find_pessimistic_batches(train_dl.sampler) for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) try: # for discriminator with autocast(enabled=params.use_fp16): diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index 0fcbb92c1..81bb9ed13 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, - SpeechSynthesisDataset, PrecomputedFeatures, SimpleCutSampler, SpecAugment, + SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py index 2a3dae900..6a067f596 100644 --- a/egs/ljspeech/TTS/vits/utils.py +++ b/egs/ljspeech/TTS/vits/utils.py @@ -14,15 +14,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union import collections import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler -from pathlib import Path from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -97,23 +97,23 @@ def plot_feature(spectrogram): global MATPLOTLIB_FLAG if not MATPLOTLIB_FLAG: import matplotlib + matplotlib.use("Agg") MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') + mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) import matplotlib.pylab as plt import numpy as np fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") plt.colorbar(im, ax=ax) plt.xlabel("Frames") plt.ylabel("Channels") plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index d5e20a578..b4f0c21e6 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn -from torch.cuda.amp import autocast - +from generator import VITSGenerator from hifigan import ( HiFiGANMultiPeriodDiscriminator, HiFiGANMultiScaleDiscriminator, @@ -25,9 +24,8 @@ from loss import 
( KLDivergenceLoss, MelSpectrogramLoss, ) +from torch.cuda.amp import autocast from utils import get_segments -from generator import VITSGenerator - AVAILABLE_GENERATERS = { "vits_generator": VITSGenerator, @@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = { class VITS(nn.Module): - """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` - """ + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" def __init__( self, diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py index fbe1be52b..5db461d5c 100644 --- a/egs/ljspeech/TTS/vits/wavenet.py +++ b/egs/ljspeech/TTS/vits/wavenet.py @@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. """ -import math import logging - +import math from typing import Optional, Tuple import torch diff --git a/egs/vctk/TTS/local/compute_spectrogram_vctk.py b/egs/vctk/TTS/local/compute_spectrogram_vctk.py new file mode 100755 index 000000000..440ac1245 --- /dev/null +++ b/egs/vctk/TTS/local/compute_spectrogram_vctk.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the VCTK dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/spectrogram. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
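+# This script is run as a standalone job from prepare.sh (stage 2), so limiting
+# the thread counts here keeps each feature-extraction process lightweight.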
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_vctk(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(32, os.cpu_count()) + + sampling_rate = 22050 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "vctk" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet + ).resample(sampling_rate=sampling_rate) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_vctk() diff --git a/egs/vctk/TTS/local/display_manifest_statistics.py b/egs/vctk/TTS/local/display_manifest_statistics.py new file mode 100755 index 000000000..0472e2cea --- /dev/null +++ b/egs/vctk/TTS/local/display_manifest_statistics.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in vits/train.py +for usage. 
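+
+For example, given the statistics printed below, one might keep cuts between
+roughly 1 and 15 seconds for this dataset.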
+""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/spectrogram/vctk_cuts_all.jsonl.gz" + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 41:02:18 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.4 │ +├───────────────────────────┼──────────┤ +│ std │ 1.2 │ +├───────────────────────────┼──────────┤ +│ min │ 1.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.6 │ +├───────────────────────────┼──────────┤ +│ 50% │ 3.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 3.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 8.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 12.1 │ +├───────────────────────────┼──────────┤ +│ max │ 16.6 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 43873 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 41:02:18 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 41:02:18 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:01 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +""" diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py new file mode 100755 index 000000000..c6636c3ad --- /dev/null +++ b/egs/vctk/TTS/local/prepare_token_file.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and generates the file that maps tokens to IDs. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict + +from lhotse import load_manifest + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-file", + type=Path, + default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"), + help="Path to the manifest file", + ) + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the tokens", + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. 
+ + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_token2id(manifest_file: Path) -> Dict[str, int]: + """Return a dict that maps token to IDs.""" + extra_tokens = [ + "", # 0 for blank + "", # 1 for sos and eos symbols. + "", # 2 for OOV + ] + all_tokens = set() + + cut_set = load_manifest(manifest_file) + + for cut in cut_set: + # Each cut only contain one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + for t in cut.tokens: + all_tokens.add(t) + + all_tokens = extra_tokens + list(all_tokens) + + token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} + return token2id + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + manifest_file = Path(args.manifest_file) + out_file = Path(args.tokens) + + token2id = get_token2id(manifest_file) + write_mapping(out_file, token2id) diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py new file mode 100755 index 000000000..32e1c7dfa --- /dev/null +++ b/egs/vctk/TTS/local/prepare_tokens_vctk.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. 
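+
+The text is first normalized with tacotron_cleaner and then converted to
+ARPAbet-style phonemes with g2p_en; for example, "Hello" typically becomes
+tokens like HH, AH0, L, OW1 (the exact output depends on the g2p_en version).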
+""" + +import logging +from pathlib import Path + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest +from tqdm.auto import tqdm + + +def prepare_tokens_vctk(): + output_dir = Path("data/spectrogram") + prefix = "vctk" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + g2p = g2p_en.G2p() + + new_cuts = [] + for cut in tqdm(cut_set): + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + cut.tokens = g2p(text) + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_vctk() diff --git a/egs/vctk/TTS/local/validate_manifest.py b/egs/vctk/TTS/local/validate_manifest.py new file mode 100755 index 000000000..cd466303e --- /dev/null +++ b/egs/vctk/TTS/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/ljspeech_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh new file mode 100755 index 000000000..87150ad31 --- /dev/null +++ b/egs/vctk/TTS/prepare.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 + +dl_dir=$PWD/download + +. 
shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/VCTK, + # you can create a symlink + # + # ln -sfv /path/to/VCTK $dl_dir/VCTK + # + if [ ! -d $dl_dir/VCTK ]; then + lhotse download vctk $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare VCTK manifest" + # We assume that you have downloaded the VCTK corpus + # to $dl_dir/VCTK + mkdir -p data/manifests + if [ ! -e data/manifests/.vctk.done ]; then + lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests + touch data/manifests/.vctk.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute spectrogram for VCTK" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.vctk.done ]; then + ./local/compute_spectrogram_vctk.py + touch data/spectrogram/.vctk.done + fi + + if [ ! -e data/spectrogram/.vctk-validated.done ]; then + log "Validating data/fbank for VCTK" + ./local/validate_manifest.py \ + data/spectrogram/vctk_cuts_all.jsonl.gz + touch data/spectrogram/.vctk-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare phoneme tokens for VCTK" + if [ ! -e data/spectrogram/.vctk_with_token.done ]; then + ./local/prepare_tokens_vctk.py + mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/vctk_cuts_all.jsonl.gz + touch data/spectrogram/.vctk_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split the VCTK cuts into train, valid and test sets" + if [ ! -e data/spectrogram/.vctk_split.done ]; then + lhotse subset --last 600 \ + data/spectrogram/vctk_cuts_all.jsonl.gz \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz \ + data/spectrogram/vctk_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz \ + data/spectrogram/vctk_cuts_test.jsonl.gz + + rm data/spectrogram/vctk_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/vctk_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/vctk_cuts_all.jsonl.gz \ + data/spectrogram/vctk_cuts_train.jsonl.gz + touch data/spectrogram/.vctk_split.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate token file" + # We assume you have installed g2p_en and espnet_tts_frontend. + # If not, please install them with: + # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! 
-e data/tokens.txt ]; then + ./local/prepare_token_file.py \ + --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \ + --tokens data/tokens.txt + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate speakers file" + if [ ! -e data/speakers.txt ]; then + gunzip -c data/manifests/vctk_supervisions_all.jsonl.gz \ + | jq '.speaker' | sed 's/"//g' \ + | sort | uniq > data/speakers.txt + fi +fi diff --git a/egs/vctk/TTS/shared b/egs/vctk/TTS/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/vctk/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/vctk/TTS/vits/duration_predictor.py b/egs/vctk/TTS/vits/duration_predictor.py new file mode 120000 index 000000000..9972b476f --- /dev/null +++ b/egs/vctk/TTS/vits/duration_predictor.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py new file mode 100755 index 000000000..7c9664cc1 --- /dev/null +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate two files inside vits/exp: + - vits-epoch-1000.onnx + - vits-epoch-1000.int8.onnx (quantizated model) + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
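+        The pairs are stored in the model's ``metadata_props`` field and are
+        preserved in the saved ONNX file.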
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + speaker: int = 20, + alpha: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + speaker (int): + Speaker ID. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + sids=speaker, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
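+
+    Note:
+      In addition to ``tokens``, the exported graph also takes ``tokens_lens``,
+      ``noise_scale``, ``noise_scale_dur``, ``speaker`` and ``alpha`` as inputs
+      (see ``input_names`` below).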
+ """ + tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + speaker = torch.tensor([1], dtype=torch.int64) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "noise_scale_dur", + "speaker", + "alpha", + ], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + "speaker": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "VITS", + "version": "1", + "model_author": "k2-fsa", + "comment": "VITS generator", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + params.num_spks = len(speaker_map) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model = model.generator + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/vctk/TTS/vits/flow.py b/egs/vctk/TTS/vits/flow.py new file mode 120000 index 000000000..e65d91ea7 --- /dev/null +++ b/egs/vctk/TTS/vits/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/generator.py b/egs/vctk/TTS/vits/generator.py new file mode 120000 index 000000000..611679bfa --- /dev/null +++ b/egs/vctk/TTS/vits/generator.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/hifigan.py b/egs/vctk/TTS/vits/hifigan.py new file mode 120000 index 000000000..5ac025de7 --- /dev/null +++ b/egs/vctk/TTS/vits/hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py new 
file mode 100755 index 000000000..06c25f02e --- /dev/null +++ b/egs/vctk/TTS/vits/infer.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +import torch.nn as nn +import torchaudio +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import VctkTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + subset: str, + params: AttributeDict, + model: nn.Module, + tokenizer: Tokenizer, + speaker_map: Dict[str, int], +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + # Background worker save audios to disk. + def _save_worker( + subset: str, + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), + audio[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"), + audio_pred[i : i + 1, : audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
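+    # Saving wavs is off-loaded to a single background thread so that disk I/O
+    # does not block inference; each finished batch is handed to _save_worker
+    # defined above, and the futures are collected at the end.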
+ + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + speakers = ( + torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]) + .int() + .to(device) + ) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + audio_pred, _, durations = model.inference_batch( + text=tokens, + text_lengths=tokens_lens, + sids=speakers, + ) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = ( + (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + ) + + futures.append( + executor.submit( + _save_worker, + subset, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + VctkTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + # we need cut ids to display recognition results. 
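+    # Setting return_cuts=True makes each batch carry its Cut objects, whose
+    # ids are used to name the saved wav files.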
+ args.return_cuts = True + vctk = VctkTtsDataModule(args) + speaker_map = vctk.speakers() + params.num_spks = len(speaker_map) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to(device) + model.eval() + + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + test_cuts = vctk.test_cuts() + test_dl = vctk.test_dataloaders(test_cuts) + + valid_cuts = vctk.valid_cuts() + valid_dl = vctk.valid_dataloaders(valid_cuts) + + infer_sets = {"test": test_dl, "valid": valid_dl} + + for subset, dl in infer_sets.items(): + save_wav_dir = params.res_dir / "wav" / subset + save_wav_dir.mkdir(parents=True, exist_ok=True) + + logging.info(f"Processing {subset} set, saving to {save_wav_dir}") + + infer_dataset( + dl=dl, + subset=subset, + params=params, + model=model, + tokenizer=tokenizer, + speaker_map=speaker_map, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/vctk/TTS/vits/loss.py b/egs/vctk/TTS/vits/loss.py new file mode 120000 index 000000000..672e5ff68 --- /dev/null +++ b/egs/vctk/TTS/vits/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/monotonic_align b/egs/vctk/TTS/vits/monotonic_align new file mode 120000 index 000000000..71934e7cc --- /dev/null +++ b/egs/vctk/TTS/vits/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/vctk/TTS/vits/posterior_encoder.py b/egs/vctk/TTS/vits/posterior_encoder.py new file mode 120000 index 000000000..41d64a3a6 --- /dev/null +++ b/egs/vctk/TTS/vits/posterior_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/residual_coupling.py b/egs/vctk/TTS/vits/residual_coupling.py new file mode 120000 index 000000000..f979adbf0 --- /dev/null +++ b/egs/vctk/TTS/vits/residual_coupling.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py new file mode 100755 index 000000000..757e67fc1 --- /dev/null +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +from pathlib import Path + +import onnxruntime as ort +import torch +import torchaudio +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__( + self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor + ) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: noise_scale_dur.numpy(), + self.model.get_inputs()[4].name: speaker.numpy(), + self.model.get_inputs()[5].name: alpha.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + args.num_spks = len(speaker_map) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." 
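+    # Any English sentence works here; the Tokenizer maps it to token ids
+    # using data/tokens.txt before synthesis.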
+ tokens = tokenizer.texts_to_token_ids([text]) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + speaker = torch.tensor([1], dtype=torch.int64) # (1, ) + audio = model(tokens, tokens_lens, speaker) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/vctk/TTS/vits/text_encoder.py b/egs/vctk/TTS/vits/text_encoder.py new file mode 120000 index 000000000..0efba277e --- /dev/null +++ b/egs/vctk/TTS/vits/text_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/tokenizer.py b/egs/vctk/TTS/vits/tokenizer.py new file mode 120000 index 000000000..057b0dc4b --- /dev/null +++ b/egs/vctk/TTS/vits/tokenizer.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py new file mode 100755 index 000000000..56f167a17 --- /dev/null +++ b/egs/vctk/TTS/vits/train.py @@ -0,0 +1,1000 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import VctkTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. 
+ + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 22050, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
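+    # load_checkpoint() below restores the model weights; the bookkeeping
+    # values (best losses/epochs, batch_idx_train) are then copied back into
+    # `params` so that training statistics resume from where they stopped.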
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + generator_params = { + "hidden_channels": 192, + "spks": params.num_spks, + "langs": None, + "spk_embed_dim": None, + "global_channels": 256, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + generator_params=generator_params, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input( + batch: dict, + tokenizer: Tokenizer, + device: torch.device, + speaker_map: Dict[str, int], +): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + speakers = ( + torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) + ) + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. 
It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
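+            # Concretely: the scale is doubled while it is below 8.0 (or below
+            # 32.0, every 400 batches); if it keeps collapsing below 1e-5 the
+            # run is aborted, since that usually means inf/nan gradients.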
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_image( + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", + ) + tb_writer.add_image( + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + speaker_map=speaker_map, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if 
isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, : tokens_lens[0].item()], + sids=speakers[0], + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = ( + (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + speaker_map: Dict[str, int], + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + vctk = VctkTtsDataModule(args) + + train_cuts = vctk.train_cuts() + speaker_map = vctk.speakers() + params.num_spks = len(speaker_map) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), 
lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = vctk.train_dataloaders(train_cuts) + + valid_cuts = vctk.valid_cuts() + valid_dl = vctk.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + speaker_map=speaker_map, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + speaker_map=speaker_map, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + 
best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + VctkTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/vctk/TTS/vits/transform.py b/egs/vctk/TTS/vits/transform.py new file mode 120000 index 000000000..962647408 --- /dev/null +++ b/egs/vctk/TTS/vits/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py new file mode 100644 index 000000000..8b2a96b09 --- /dev/null +++ b/egs/vctk/TTS/vits/tts_datamodule.py @@ -0,0 +1,338 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class VctkTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + 
feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") + + @lru_cache() + def speakers(self) -> Dict[str, int]: + logging.info("About to get speakers") + with open(self.args.speakers) as f: + speakers = {line.strip(): i for i, line in enumerate(f)} + return speakers diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py new file mode 120000 index 000000000..085e764b4 --- /dev/null +++ b/egs/vctk/TTS/vits/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py new file mode 120000 index 000000000..1f58cf6fe --- /dev/null +++ b/egs/vctk/TTS/vits/vits.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py new file mode 120000 index 000000000..28f0a78ee --- /dev/null +++ b/egs/vctk/TTS/vits/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file diff --git a/requirements-tts.txt b/requirements-tts.txt new file mode 100644 index 000000000..c30e23d54 --- /dev/null +++ b/requirements-tts.txt @@ -0,0 +1,6 @@ +# for TTS recipes +matplotlib==3.8.2 +cython==3.0.6 +numba==0.58.1 +g2p_en==2.1.0 +espnet_tts_frontend==0.0.3 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9502fcbd2..a1a46ae64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ tensorboard typeguard dill black==22.3.0 +onnx==1.15.0 +onnxruntime==1.16.3 \ No newline at end of file From b87ed26c09e9f5bb29174dd01f13670fb6124583 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:33:45 +0800 Subject: [PATCH 055/216] Normalize dockerfile (#1400) --- docker/torch1.12.1-cuda11.3.dockerfile | 3 +-- docker/torch1.13.0-cuda11.6.dockerfile | 3 +-- docker/torch1.9.0-cuda10.2.dockerfile | 3 +-- docker/torch2.0.0-cuda11.7.dockerfile | 3 +-- docker/torch2.1.0-cuda11.8.dockerfile | 3 +-- docker/torch2.1.0-cuda12.1.dockerfile | 3 +-- 6 files changed, 6 insertions(+), 12 deletions(-) diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index ed746abe3..deb5715cc 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f 
https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index 9657866e5..afc6c1b84 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile index a92af7ad0..9ff225b54 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -58,7 +58,6 @@ RUN pip uninstall -y tqdm && \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index 07296e6f0..db8076560 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index e500e9a6a..b006b0d96 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index c3f12323e..1b078dc22 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ From bda72f86fffe591d334630da522dba4cf5c66341 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 8 Dec 2023 06:32:40 +0800 Subject: [PATCH 
056/216] minor adjustments to the VITS recipes for onnx runtime (#1405) --- egs/ljspeech/TTS/vits/export-onnx.py | 4 ++-- egs/ljspeech/TTS/vits/test_onnx.py | 4 ++-- egs/vctk/TTS/vits/export-onnx.py | 4 ++-- egs/vctk/TTS/vits/test_onnx.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 2068adeea..bca6aec99 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -176,7 +176,7 @@ def export_model_onnx( torch.onnx.export( model, - (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur), model_filename, verbose=False, opset_version=opset_version, @@ -184,8 +184,8 @@ def export_model_onnx( "tokens", "tokens_lens", "noise_scale", - "noise_scale_dur", "alpha", + "noise_scale_dur", ], output_names=["audio"], dynamic_axes={ diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 686fee2a0..fcbc1d663 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -92,8 +92,8 @@ class OnnxModel: self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: alpha.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), }, )[0] return torch.from_numpy(out) diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index 7c9664cc1..cfc74fd0a 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -187,7 +187,7 @@ def export_model_onnx( torch.onnx.export( model, - (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha), + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker), model_filename, verbose=False, opset_version=opset_version, @@ -195,9 +195,9 @@ def export_model_onnx( "tokens", "tokens_lens", "noise_scale", + "alpha", "noise_scale_dur", "speaker", - "alpha", ], output_names=["audio"], dynamic_axes={ diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py index 757e67fc1..d85c0a27b 100755 --- a/egs/vctk/TTS/vits/test_onnx.py +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -101,9 +101,9 @@ class OnnxModel: self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: speaker.numpy(), - self.model.get_inputs()[5].name: alpha.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), + self.model.get_inputs()[5].name: speaker.numpy(), }, )[0] return torch.from_numpy(out) From e9ec827de76856e38af7a884b878ca3a84f64bb9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Dec 2023 14:29:24 +0800 Subject: [PATCH 057/216] Rename zipformer2 to zipformer_for_ncnn_export_only to avoid confusion. 
(#1407) --- .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../{zipformer2.py => zipformer_for_ncnn_export_only.py} | 0 .../pruned_transducer_stateless7_streaming_multi/zipformer2.py | 1 - 12 files changed, 7 insertions(+), 8 deletions(-) delete mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py delete mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py delete mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py rename egs/librispeech/ASR/pruned_transducer_stateless7_streaming/{zipformer2.py => zipformer_for_ncnn_export_only.py} (100%) delete mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 3c13c19c6..0fba3b58f 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -66,7 +66,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 61a3f27db..0426bc9a3 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -67,7 +67,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index acde72d80..685f6ece6 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -70,7 +70,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index cd26db6f3..9a6d2155b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -69,7 +69,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import 
DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py similarity index 100% rename from egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py rename to egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py deleted file mode 120000 index d3625f478..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file From df56aff31ea9b95aa3d9672398a0771dcb8eacc5 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 8 Dec 2023 21:11:31 +0800 Subject: [PATCH 058/216] minor fixes to the vits onnx exportation scripts (#1408) --- egs/ljspeech/TTS/vits/export-onnx.py | 2 +- egs/vctk/TTS/vits/export-onnx.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index bca6aec99..36a9de27f 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -115,8 +115,8 @@ class OnnxModel(nn.Module): tokens: torch.Tensor, tokens_lens: torch.Tensor, noise_scale: float = 0.667, - noise_scale_dur: float = 0.8, alpha: float = 1.0, + noise_scale_dur: float = 0.8, ) -> Tuple[torch.Tensor, torch.Tensor]: """Please see the help information of VITS.inference_batch diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index cfc74fd0a..667ac284b 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -121,9 +121,9 @@ class OnnxModel(nn.Module): tokens: torch.Tensor, tokens_lens: torch.Tensor, noise_scale: float = 0.667, + alpha: float = 1.0, noise_scale_dur: float = 0.8, speaker: int = 20, - alpha: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Please see the help information of VITS.inference_batch From b0f70c9d042da734a5df988b98412e5def6b8072 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 10 Dec 2023 11:38:39 +0800 Subject: [PATCH 059/216] Fix torch.jit.script() export for pruned_transducer_stateless2 (#1410) --- egs/librispeech/ASR/pruned_transducer_stateless2/export.py | 2 ++ egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py | 1 + .../ASR/pruned_transducer_stateless2/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index e02afa892..e2db98f73 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -49,6 +49,7 @@ from pathlib import Path import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, 
get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint @@ -198,6 +199,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file From 20a82c9abf6664b645c28dd6b3d629cf85e40231 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Dec 2023 18:13:26 +0800 Subject: [PATCH 060/216] first commit (#1411) --- .github/scripts/multi-zh-hans.sh | 88 +++++++++++++++++++ .github/workflows/multi-zh-hans.yml | 84 ++++++++++++++++++ .../ASR/zipformer/export-onnx-streaming.py | 14 ++- 3 files changed, 182 insertions(+), 4 deletions(-) create mode 100755 .github/scripts/multi-zh-hans.sh create mode 100644 .github/workflows/multi-zh-hans.yml diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh new file mode 100755 index 000000000..4ede7f43e --- /dev/null +++ b/.github/scripts/multi-zh-hans.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "pwd: $PWD" + +cd egs/multi_zh-hans/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +cd exp/ +git lfs pull --include pretrained.pt +rm -fv epoch-20.pt +rm -fv *.onnx +ln -s pretrained.pt epoch-20.pt +cd ../data/lang_bpe_2000 +git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model +popd + +log "----------------------------------------" +log "Export streaming ONNX transducer models " +log "----------------------------------------" + +./zipformer/export-onnx-streaming.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --causal 1 \ + --avg 1 \ + --epoch 20 \ + --use-averaged-model 0 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 0 + +ls -lh $repo/exp + +log "------------------------------------------------------------" +log "Test export streaming ONNX transducer models (Python code) " +log "------------------------------------------------------------" + +log "test fp32" +./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav + +log "test int8" 
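+# Decode again with the int8-quantized encoder and joiner; the decoder is the
+# same fp32 model used in the fp32 test above.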
+./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav + +log "Upload models to huggingface" +git config --global user.name "k2-fsa" +git config --global user.email "xxx@gmail.com" + +url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $url +dst=$(basename $url) +cp -v $repo/exp/*.onnx $dst +cp -v $repo/data/lang_bpe_2000/tokens.txt $dst +mkdir -p $dst/test_wavs +cp -v $repo/test_wavs/*.wav $dst/test_wavs +cd $dst +git lfs track "*.onnx" +git add . +git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true + +log "Upload models to https://github.com/k2-fsa/sherpa-onnx" +rm -rf .git +rm -fv .gitattributes +cd .. +tar cjfv $dst.tar.bz2 $dst +mv -v $dst.tar.bz2 ../../../ diff --git a/.github/workflows/multi-zh-hans.yml b/.github/workflows/multi-zh-hans.yml new file mode 100644 index 000000000..439300b5f --- /dev/null +++ b/.github/workflows/multi-zh-hans.yml @@ -0,0 +1,84 @@ +name: run-multi-zh-hans + +on: + push: + branches: + - master + - upload-ctc-model + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: run-multi-zh-hans-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: write + +jobs: + multi-zh-hans: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: export-model + shell: bash + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/multi-zh-hans.sh + ls -lh + + - name: upload model to https://github.com/k2-fsa/sherpa-onnx + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + file: ./*.tar.bz2 + overwrite: true + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: asr-models diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index e2c7d7d95..6bc9b1858 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -614,7 +614,9 @@ def main(): ) logging.info(f"averaging {filenames}") model.to(device) - 
model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: @@ -625,7 +627,9 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ @@ -653,7 +657,8 @@ def main(): filename_start=filename_start, filename_end=filename_end, device=device, - ) + ), + strict=False, ) else: assert params.avg > 0, params.avg @@ -671,7 +676,8 @@ def main(): filename_start=filename_start, filename_end=filename_end, device=device, - ) + ), + strict=False, ) model.to("cpu") From 9e9fe7954d13c5ee8f10e990b358cf8c752a24e6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Dec 2023 18:57:04 +0800 Subject: [PATCH 061/216] Upload gigaspeech zipformer models in CI (#1412) --- .github/scripts/multi-zh-hans.sh | 3 +- .../run-gigaspeech-zipformer-2023-10-17.sh | 74 +++++++++++++++++-- .github/workflows/multi-zh-hans.yml | 5 -- .../run-gigaspeech-zipformer-2023-10-17.yml | 14 ++++ 4 files changed, 85 insertions(+), 11 deletions(-) diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh index 4ede7f43e..2dd1bce42 100755 --- a/.github/scripts/multi-zh-hans.sh +++ b/.github/scripts/multi-zh-hans.sh @@ -45,7 +45,7 @@ log "----------------------------------------" ls -lh $repo/exp log "------------------------------------------------------------" -log "Test export streaming ONNX transducer models (Python code) " +log "Test exported streaming ONNX transducer models (Python code)" log "------------------------------------------------------------" log "test fp32" @@ -73,6 +73,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $url dst=$(basename $url) cp -v $repo/exp/*.onnx $dst cp -v $repo/data/lang_bpe_2000/tokens.txt $dst +cp -v $repo/data/lang_bpe_2000/bpe.model $dst mkdir -p $dst/test_wavs cp -v $repo/test_wavs/*.wav $dst/test_wavs cd $dst diff --git a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh index 6bb0b9ebc..329896ef6 100755 --- a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh +++ b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh @@ -26,16 +26,80 @@ git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/jit_script.pt" git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt +rm epoch-30.pt +ln -s pretrained.pt epoch-30.pt +rm *.onnx +ls -lh popd +log "----------------------------------------" +log "Export ONNX transducer models " +log "----------------------------------------" + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 30 \ + --avg 1 \ + --exp-dir $repo/exp + +ls -lh $repo/exp + +log "------------------------------------------------------------" +log "Test exported ONNX transducer models (Python code) " +log "------------------------------------------------------------" + +log "test fp32" +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-30-avg-1.onnx \ + --decoder-model-filename 
$repo/exp/decoder-epoch-30-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-30-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "test int8" +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-30-avg-1.int8.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-30-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-30-avg-1.int8.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Upload models to huggingface" +git config --global user.name "k2-fsa" +git config --global user.email "xxx@gmail.com" + +url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-gigaspeech-2023-12-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $url +dst=$(basename $url) +cp -v $repo/exp/*.onnx $dst +cp -v $repo/data/lang_bpe_500/tokens.txt $dst +cp -v $repo/data/lang_bpe_500/bpe.model $dst +mkdir -p $dst/test_wavs +cp -v $repo/test_wavs/*.wav $dst/test_wavs +cd $dst +git lfs track "*.onnx" +git add . +git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true + +log "Upload models to https://github.com/k2-fsa/sherpa-onnx" +rm -rf .git +rm -fv .gitattributes +cd .. +tar cjfv $dst.tar.bz2 $dst +ls -lh +mv -v $dst.tar.bz2 ../../../ + log "Export to torchscript model" ./zipformer/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ + --epoch 30 \ --avg 1 \ --jit 1 @@ -67,7 +131,7 @@ echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then mkdir -p zipformer/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt + ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-30.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ ls -lh data @@ -83,7 +147,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == ./zipformer/decode.py \ --decoding-method $method \ - --epoch 999 \ + --epoch 30 \ --avg 1 \ --use-averaged-model 0 \ --max-duration $max_duration \ diff --git a/.github/workflows/multi-zh-hans.yml b/.github/workflows/multi-zh-hans.yml index 439300b5f..9081047de 100644 --- a/.github/workflows/multi-zh-hans.yml +++ b/.github/workflows/multi-zh-hans.yml @@ -2,11 +2,6 @@ name: run-multi-zh-hans on: push: - branches: - - master - - upload-ctc-model - - pull_request: branches: - master diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml index 7572f4b5f..87090e310 100644 --- a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml +++ b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml @@ -21,6 +21,7 @@ on: push: branches: - master + pull_request: types: [labeled] @@ -33,6 +34,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: run_gigaspeech_2023_10_17_zipformer-${{ github.ref }} cancel-in-progress: true @@ -85,6 +88,7 @@ jobs: env: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | mkdir -p egs/gigaspeech/ASR/data ln -sfv ~/tmp/fbank-libri 
egs/gigaspeech/ASR/data/fbank @@ -97,6 +101,16 @@ jobs: .github/scripts/run-gigaspeech-zipformer-2023-10-17.sh + - name: upload model to https://github.com/k2-fsa/sherpa-onnx + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + file: ./*.tar.bz2 + overwrite: true + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: asr-models + - name: Display decoding results for gigaspeech zipformer if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' shell: bash From d0da509055468f600808a3d7cd8885649d12b96d Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 13 Dec 2023 10:33:28 +0800 Subject: [PATCH 062/216] Support ONNX export for Streaming CTC Encoder (#1413) * Create export-onnx-streaming-ctc.py * doc_str updated Co-authored-by: Fangjun Kuang --------- Co-authored-by: Fangjun Kuang --- .../zipformer/export-onnx-streaming-ctc.py | 570 ++++++++++++++++++ 1 file changed, 570 insertions(+) create mode 100755 egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 100755 index 000000000..3c0f74005 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1,570 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang, Zengrui Jin) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a CTC model from PyTorch to ONNX. + + +1. Download the pre-trained streaming model with CTC head + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 \ + --use-ctc 1 + +The --chunk-size in training is "16,32,64,-1", so we select one of them +(excluding -1) during streaming export. The same applies to `--left-context`, +whose value is "64,128,256,-1". + +It will generate the following file inside $repo/exp: + + - ctc-epoch-99-avg-1-chunk-16-left-64.onnx + +See ./onnx_pretrained-streaming-ctc.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. 
+ Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for Zipformer and the ctc_head""" + + def __init__( + self, + encoder: Zipformer2, + encoder_embed: nn.Module, + ctc_output: nn.Module, + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + ctc_output: + The ctc head. 
+ """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.ctc_output = ctc_output + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + self.pad_length = 7 + 2 * 3 + + def forward( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + N = x.size(0) + T = self.chunk_size * 2 + self.pad_length + x_lens = torch.tensor([T] * N, device=x.device) + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=x, + x_lens=x_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) + + src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) + encoder_states = states[:-2] + logging.info(f"len_encoder_states={len(encoder_states)}") + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.ctc_output(encoder_out) + # Now encoder_out is of shape (N, T, ctc_output_dim) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + + return encoder_out, new_states + + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
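+
+        A rough sketch of unpacking this layout (illustration only; it simply
+        restates the description above, and the loop bound is whatever
+        len(states[:-2]) // 6 evaluates to):
+
+            for i in range(len(states[:-2]) // 6):
+                (cached_key, cached_nonlin_attn, cached_val1,
+                 cached_val2, cached_conv1, cached_conv2) = states[i * 6 : (i + 1) * 6]
+            embed_states = states[-2]    # (batch_size, num_channels, left_pad, num_freqs)
+            processed_lens = states[-1]  # (batch,)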
+ """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) + states.append(processed_lens) + + return states + + +def export_streaming_ctc_model_onnx( + model: OnnxModel, + encoder_filename: str, + opset_version: int = 11, +) -> None: + model.encoder.__class__.forward = model.encoder.__class__.streaming_forward + + decode_chunk_len = model.chunk_size * 2 + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + T = decode_chunk_len + model.pad_length + + x = torch.rand(1, T, 80, dtype=torch.float32) + init_state = model.get_init_states() + num_encoders = len(model.encoder.encoder_dim) + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["log_probs"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + logging.info(f"{name}.shape: {tensors[0].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + logging.info(f"{name}.shape: {tensors[1].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + logging.info(f"{name}.shape: {tensors[2].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + logging.info(f"{name}.shape: {tensors[3].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + logging.info(f"{name}.shape: {tensors[4].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + logging.info(f"{name}.shape: {tensors[5].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + num_encoder_layers = ",".join(map(str, model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, model.encoder.encoder_dim)) + cnn_module_kernels = ",".join(map(str, model.encoder.cnn_module_kernel)) + ds = model.encoder.downsampling_factor + left_context_len = model.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + query_head_dims = ",".join(map(str, model.encoder.query_head_dim)) + value_head_dims = ",".join(map(str, model.encoder.value_head_dim)) + num_heads = ",".join(map(str, model.encoder.num_heads)) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "streaming ctc zipformer2", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 32+7+2*3=45 + "num_encoder_layers": num_encoder_layers, + 
"encoder_dims": encoder_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + "query_head_dims": query_head_dims, + "value_head_dims": value_head_dims, + "num_heads": num_heads, + } + logging.info(f"meta_data: {meta_data}") + + for i in range(len(init_state[:-2]) // 6): + build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + embed_states = init_state[-2] + name = "embed_states" + logging.info(f"{name}.shape: {embed_states.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size,) + processed_lens = init_state[-1] + name = "processed_lens" + logging.info(f"{name}.shape: {processed_lens.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "log_probs": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + model = OnnxModel( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + ctc_output=model.ctc_output, + ) + + total_num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + suffix += f"-chunk-{params.chunk_size}" + suffix += f"-left-{params.left_context_frames}" + + opset_version = 13 + + logging.info("Exporting model") + model_filename = params.exp_dir / f"ctc-{suffix}.onnx" + export_streaming_ctc_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported model to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() From f85f0252a9f05c10e6782a693c87a1fede2f83c7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 13 Dec 2023 17:34:12 +0800 Subject: [PATCH 063/216] Add greedy search for streaming zipformer CTC. 
(#1415) --- .github/scripts/multi-zh-hans.sh | 79 +++- .../onnx_pretrained-streaming-ctc.py | 426 ++++++++++++++++++ .../zipformer/onnx_pretrained-streaming.py | 2 +- .../zipformer/export-onnx-streaming-ctc.py | 1 + .../onnx_pretrained-streaming-ctc.py | 1 + 5 files changed, 503 insertions(+), 6 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh index 2dd1bce42..427d8887b 100755 --- a/.github/scripts/multi-zh-hans.sh +++ b/.github/scripts/multi-zh-hans.sh @@ -2,6 +2,10 @@ set -ex +git config --global user.name "k2-fsa" +git config --global user.email "csukuangfj@gmail.com" +git config --global lfs.allowincompletepush true + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} @@ -24,9 +28,73 @@ rm -fv epoch-20.pt rm -fv *.onnx ln -s pretrained.pt epoch-20.pt cd ../data/lang_bpe_2000 +ls -lh git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model +git lfs pull --include "*.model" +ls -lh popd +log "----------------------------------------" +log "Export streaming ONNX CTC models " +log "----------------------------------------" +./zipformer/export-onnx-streaming-ctc.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --causal 1 \ + --avg 1 \ + --epoch 20 \ + --use-averaged-model 0 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 1 + +ls -lh $repo/exp/ + +log "------------------------------------------------------------" +log "Test exported streaming ONNX CTC models (greedy search) " +log "------------------------------------------------------------" + +test_wavs=( +DEV_T0000000000.wav +DEV_T0000000001.wav +DEV_T0000000002.wav +TEST_MEETING_T0000000113.wav +TEST_MEETING_T0000000219.wav +TEST_MEETING_T0000000351.wav +) + +for w in ${test_wavs[@]}; do + ./zipformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/$w +done + +log "Upload onnx CTC models to huggingface" +url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $url +dst=$(basename $url) +cp -v $repo/exp/ctc*.onnx $dst +cp -v $repo/data/lang_bpe_2000/tokens.txt $dst +cp -v $repo/data/lang_bpe_2000/bpe.model $dst +mkdir -p $dst/test_wavs +cp -v $repo/test_wavs/*.wav $dst/test_wavs +cd $dst +git lfs track "*.onnx" "bpe.model" +ls -lh +file bpe.model +git status +git add . +git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true + +log "Upload models to https://github.com/k2-fsa/sherpa-onnx" +rm -rf .git +rm -fv .gitattributes +cd .. 
+tar cjfv $dst.tar.bz2 $dst +ls -lh *.tar.bz2 +mv -v $dst.tar.bz2 ../../../ + log "----------------------------------------" log "Export streaming ONNX transducer models " log "----------------------------------------" @@ -64,20 +132,20 @@ log "test int8" --tokens $repo/data/lang_bpe_2000/tokens.txt \ $repo/test_wavs/DEV_T0000000000.wav -log "Upload models to huggingface" -git config --global user.name "k2-fsa" -git config --global user.email "xxx@gmail.com" +log "Upload onnx transducer models to huggingface" url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12 GIT_LFS_SKIP_SMUDGE=1 git clone $url dst=$(basename $url) -cp -v $repo/exp/*.onnx $dst +cp -v $repo/exp/encoder*.onnx $dst +cp -v $repo/exp/decoder*.onnx $dst +cp -v $repo/exp/joiner*.onnx $dst cp -v $repo/data/lang_bpe_2000/tokens.txt $dst cp -v $repo/data/lang_bpe_2000/bpe.model $dst mkdir -p $dst/test_wavs cp -v $repo/test_wavs/*.wav $dst/test_wavs cd $dst -git lfs track "*.onnx" +git lfs track "*.onnx" bpe.model git add . git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true @@ -86,4 +154,5 @@ rm -rf .git rm -fv .gitattributes cd .. tar cjfv $dst.tar.bz2 $dst +ls -lh *.tar.bz2 mv -v $dst.tar.bz2 ../../../ diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py new file mode 100755 index 000000000..44546cae5 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming-ctc.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 1 + +It will generate the following 2 files inside $repo/exp: + + - ctc-epoch-99-avg-1-chunk-16-left-128.int8.onnx + - ctc-epoch-99-avg-1-chunk-16-left-128.onnx + +You can use either the ``int8.onnx`` model or just the ``.onnx`` model. + +3. Run this file with the exported ONNX models + +./zipformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-99-avg-1-chunk-16-left-128.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. 
+""" + +import argparse +import logging +from typing import Dict, List, Tuple + +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(model_filename) + + def init_model(self, model_filename: str): + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + self.init_states() + + def init_states(self, batch_size: int = 1): + meta = self.model.get_modelmeta().custom_metadata_map + logging.info(f"meta={meta}") + + model_type = meta["model_type"] + assert model_type == "zipformer2", model_type + + decode_chunk_len = int(meta["decode_chunk_len"]) + T = int(meta["T"]) + + num_encoder_layers = meta["num_encoder_layers"] + encoder_dims = meta["encoder_dims"] + cnn_module_kernels = meta["cnn_module_kernels"] + left_context_len = meta["left_context_len"] + query_head_dims = meta["query_head_dims"] + value_head_dims = meta["value_head_dims"] + num_heads = meta["num_heads"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, 
conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def _build_model_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + model_input = {"x": x.numpy()} + model_output = ["log_probs"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + model_input[name] = tensors[0] + model_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + model_input[name] = tensors[1] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + model_input[name] = tensors[2] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + model_input[name] = tensors[3] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + model_input[name] = tensors[4] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + model_input[name] = tensors[5] + model_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + model_input[name] = embed_states + model_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + model_input[name] = processed_lens + model_output.append(f"new_{name}") + + return model_input, model_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor containing log_probs. Its shape is (N, T, vocab_size) + where T' is usually equal to ((T-7)//2 - 3)//2 + """ + model_input, model_output_names = self._build_model_input_output(x) + + out = self.model.run(model_output_names, model_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. 
In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + log_probs: torch.Tensor, +) -> List[int]: + """Greedy search for a single utterance. + Args: + log_probs: + A 3-D tensor of shape (1, T, vocab_size) + Returns: + Return the decoded result. + """ + assert log_probs.ndim == 3, log_probs.shape + assert log_probs.shape[0] == 1, log_probs.shape + + max_indexes = log_probs[0].argmax(dim=1) + unique_indexes = torch.unique_consecutive(max_indexes) + + blank_id = 0 + unique_indexes = unique_indexes[unique_indexes != blank_id] + return unique_indexes.tolist() + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel(model_filename=args.model_filename) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + hyp = [] + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + log_probs = model(frames) + + hyp += greedy_search(log_probs) + + # To handle byte-level BPE, we convert string tokens to utf-8 encoded bytes + id2token = {} + with open(args.tokens, encoding="utf-8") as f: + for line in f: + token, idx = line.split() + if token[:3] == "<0x" and token[-1] == ">": + token = int(token[1:-1], base=16) + assert 0 <= token < 256, token + token = token.to_bytes(1, byteorder="little") + else: + token = token.encode(encoding="utf-8") + + id2token[int(idx)] = token + + text = b"" + for i in hyp: + text += id2token[i] + text = text.decode(encoding="utf-8") + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index e62491444..e7c4f40ee 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -326,7 +326,7 @@ class OnnxModel: A 3-D tensor of shape (N, T, C) Returns: Return a 3-D tensor of shape (N, T', joiner_dim) where - T' is usually equal to ((T-7)//2+1)//2 + T' is usually equal to ((T-7)//2-3)//2 """ encoder_input, 
encoder_output_names = self._build_encoder_input_output(x) diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 120000 index 000000000..652346001 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py new file mode 120000 index 000000000..d623a8462 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file From 10a234709cb6b8aa5e99b1c18140b49db7b6faca Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 14 Dec 2023 11:26:37 +0800 Subject: [PATCH 064/216] bugs fixed (#1416) --- egs/aishell4/ASR/prepare.sh | 2 +- egs/alimeeting/ASR/prepare.sh | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index 361cc26ab..e8d9eb7b9 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -79,7 +79,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Process aishell4" if [ ! -f data/fbank/aishell4/.fbank.done ]; then mkdir -p data/fbank/aishell4 - lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4 + ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} touch data/fbank/aishell4/.fbank.done fi fi diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 1709733c7..c8fed658d 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -7,6 +7,7 @@ set -eou pipefail stage=-1 stop_stage=100 +perturb_speed=true # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -68,7 +69,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Process alimeeting" if [ ! 
-f data/fbank/alimeeting/.fbank.done ]; then mkdir -p data/fbank/alimeeting - lhotse prepare ali-meeting $dl_dir/alimeeting data/manifests/alimeeting + ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} fi fi From 702d4f59147a81ef7ba37c7863ae7bd258c743a9 Mon Sep 17 00:00:00 2001 From: TianHao Zhang <32243340+Zth9730@users.noreply.github.com> Date: Thu, 21 Dec 2023 14:42:33 +0800 Subject: [PATCH 065/216] Update prepare.sh (#1422) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix the bug in line 251: 1、 del the additional blank 2、correct the spell error of "new_vocab_size" --- egs/libriheavy/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh index af7e3c5b0..b0736c98b 100755 --- a/egs/libriheavy/ASR/prepare.sh +++ b/egs/libriheavy/ASR/prepare.sh @@ -248,7 +248,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts fi for vocab_size in ${vocab_sizes[@]}; do - new_vacab_size = $(($vocab_size + 256)) + new_vocab_size=$(($vocab_size + 256)) lang_dir=data/lang_punc_bpe_${new_vocab_size} mkdir -p $lang_dir From 79a42148dbcd98c42586f8386d91f6f4bb8f9979 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 23 Dec 2023 00:38:36 +0800 Subject: [PATCH 066/216] Add CI test to cover zipformer/train.py (#1424) --- .github/scripts/docker/Dockerfile | 59 +++++++++++++++ .github/scripts/docker/run.sh | 60 +++++++++++++++ .github/workflows/build-cpu-docker.yml | 75 +++++++++++++++++++ .github/workflows/train-librispeech.yml | 56 ++++++++++++++ egs/gigaspeech/ASR/zipformer/my_profile.py | 1 + egs/gigaspeech/ASR/zipformer/profile.py | 1 - .../{profile.py => my_profile.py} | 2 +- .../{profile.py => my_profile.py} | 2 +- .../{profile.py => my_profile.py} | 2 +- .../zipformer/{profile.py => my_profile.py} | 2 +- egs/tedlium3/ASR/zipformer/my_profile.py | 1 + egs/tedlium3/ASR/zipformer/profile.py | 1 - 12 files changed, 256 insertions(+), 6 deletions(-) create mode 100644 .github/scripts/docker/Dockerfile create mode 100755 .github/scripts/docker/run.sh create mode 100644 .github/workflows/build-cpu-docker.yml create mode 100644 .github/workflows/train-librispeech.yml create mode 120000 egs/gigaspeech/ASR/zipformer/my_profile.py delete mode 120000 egs/gigaspeech/ASR/zipformer/profile.py rename egs/librispeech/ASR/pruned_transducer_stateless/{profile.py => my_profile.py} (98%) rename egs/librispeech/ASR/pruned_transducer_stateless4/{profile.py => my_profile.py} (98%) rename egs/librispeech/ASR/pruned_transducer_stateless7/{profile.py => my_profile.py} (98%) rename egs/librispeech/ASR/zipformer/{profile.py => my_profile.py} (99%) create mode 120000 egs/tedlium3/ASR/zipformer/my_profile.py delete mode 120000 egs/tedlium3/ASR/zipformer/profile.py diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile new file mode 100644 index 000000000..55c3aa1b9 --- /dev/null +++ b/.github/scripts/docker/Dockerfile @@ -0,0 +1,59 @@ +ARG PYTHON_VERSION=3.8 +FROM python:${PYTHON_VERSION} + +ARG TORCHAUDIO_VERSION="0.13.0" +ARG TORCH_VERSION="1.13.0" +ARG K2_VERSION="1.24.4.dev20231220" +ARG KALDIFEAT_VERSION="1.25.3.dev20231221" + +ARG _K2_VERSION="${K2_VERSION}+cpu.torch${TORCH_VERSION}" +ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}" + +RUN apt-get update -y && \ + apt-get install -qq -y \ + ffmpeg \ + git \ + git-lfs \ + less \ + vim \ + && \ + 
apt-get clean && \ + rm -rf /var/cache/apt/archives /var/lib/apt/lists + + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${_K2_VERSION} +LABEL kaldifeat_version=${_KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +# Install dependencies +RUN pip install --no-cache-dir \ + torch==${TORCH_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/cpu/torch_stable.html \ + k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + six \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +# RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ +# cd /workspace/icefall && \ +# pip install --no-cache-dir -r requirements.txt +# +# ENV PYTHONPATH /workspace/icefall:$PYTHONPATH +# +# WORKDIR /workspace/icefall diff --git a/.github/scripts/docker/run.sh b/.github/scripts/docker/run.sh new file mode 100755 index 000000000..aeb80b330 --- /dev/null +++ b/.github/scripts/docker/run.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +set -ex + +cd /icefall +export PYTHONPATH=/icefall:$PYTHONPATH +python3 -c "import torch; print(torch.__file__)" +python3 -c "import torchaudio; print(torchaudio.__version__)" +python3 -c "import icefall; print(icefall.__file__)" + +cd egs/librispeech/ASR + +# We don't download the LM file since it is so large that it will +# cause OOM error for CI later. +mkdir -p download/lm +pushd download/lm +wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt +wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt +wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz +ls -lh +gunzip librispeech-lm-norm.txt.gz + +ls -lh +popd + +pushd download/ +wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2 +tar xf LibriSpeech.tar.bz2 +rm LibriSpeech.tar.bz2 + +cd LibriSpeech +ln -s train-clean-100 train-clean-360 +ln -s train-other-500 train-other-500 +popd + +mkdir -p data/manifests + +lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests +ls -lh data/manifests + +./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False +ls -lh data/fbank + +./prepare.sh --stage 5 --stop-stage 6 + +./zipformer/train.py \ + --world-size 1 \ + --num-epochs 1 \ + --start-epoch 1 \ + --use-fp16 0 \ + --exp-dir zipformer/exp-small \ + --causal 0 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 64,96,96,96,96,96 \ + --encoder-dim 32,64,64,64,64,64 \ + --encoder-unmasked-dim 32,32,32,32,32,32 \ + --base-lr 0.04 \ + --full-libri 0 \ + --enable-musan 0 \ + --max-duration 30 \ + --print-diagnostics 1 diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml new file mode 100644 index 000000000..f931f7d09 --- /dev/null +++ b/.github/workflows/build-cpu-docker.yml @@ -0,0 +1,75 @@ +name: build-cpu-docker +on: + workflow_dispatch: + +concurrency: + group: build-cpu-docker-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-cpu-docker: + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ${{ matrix.os 
}} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.8", "3.9", "3.10"] + torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + k2-version: ["1.24.4.dev20231220"] + kaldifeat-version: ["1.25.3.dev20231221"] + version: ["1.0"] + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + + - name: 'Login to GitHub Container Registry' + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build docker Image + shell: bash + run: | + cd .github/scripts/docker + torch_version=${{ matrix.torch-version }} + if [[ $torch_version == 1.13.0 ]]; then + torchaudio_version=0.13.0 + elif [[ $torch_version == 2.0.0 ]]; then + torchaudio_version=2.0.1 + elif [[ $torch_version == 2.0.1 ]]; then + torchaudio_version=2.0.2 + else + torchaudio_version=$torch_version + fi + echo "torch_version: $torch_version" + echo "torchaudio_version: $torchaudio_version" + + version=${{ matrix.version }} + + tag=ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version + echo "tag: $tag" + + docker build \ + -t $tag \ + --build-arg PYTHON_VERSION=${{ matrix.python-version }} \ + --build-arg TORCH_VERSION=$torch_version \ + --build-arg TORCHAUDIO_VERSION=$torchaudio_version \ + --build-arg K2_VERSION=${{ matrix.k2-version }} \ + --build-arg KALDIFEAT_VERSION=${{ matrix.kaldifeat-version }} \ + . + + docker image ls + docker push $tag diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml new file mode 100644 index 000000000..7c9a28f03 --- /dev/null +++ b/.github/workflows/train-librispeech.yml @@ -0,0 +1,56 @@ +name: train librispeech +on: + push: + branches: + - master + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: train-librispeech-${{ github.ref }} + cancel-in-progress: true + +jobs: + train-librispeech: + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.8", "3.9", "3.10"] + torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + k2-version: ["1.24.4.dev20231220"] + kaldifeat-version: ["1.25.3.dev20231221"] + version: ["1.0"] + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run the build process with Docker + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + ls -lh /icefall + + /icefall/.github/scripts/docker/run.sh diff --git a/egs/gigaspeech/ASR/zipformer/my_profile.py b/egs/gigaspeech/ASR/zipformer/my_profile.py new file mode 120000 index 000000000..3a90b2628 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/my_profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/profile.py b/egs/gigaspeech/ASR/zipformer/profile.py deleted 
file mode 120000 index c93adbd14..000000000 --- a/egs/gigaspeech/ASR/zipformer/profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless/profile.py rename to egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py index 09e4a7af4..b844ba613 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py @@ -17,7 +17,7 @@ # limitations under the License. """ -Usage: ./pruned_transducer_stateless/profile.py +Usage: ./pruned_transducer_stateless/my_profile.py """ import argparse diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless4/profile.py rename to egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py index 252bdf060..4bf773918 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py @@ -17,7 +17,7 @@ # limitations under the License. """ -Usage: ./pruned_transducer_stateless4/profile.py +Usage: ./pruned_transducer_stateless4/my_profile.py """ import argparse diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless7/profile.py rename to egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py index 0d308e966..5a068b3b6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py @@ -17,7 +17,7 @@ # limitations under the License. """ -Usage: ./pruned_transducer_stateless7/profile.py +Usage: ./pruned_transducer_stateless7/my_profile.py """ import argparse diff --git a/egs/librispeech/ASR/zipformer/profile.py b/egs/librispeech/ASR/zipformer/my_profile.py similarity index 99% rename from egs/librispeech/ASR/zipformer/profile.py rename to egs/librispeech/ASR/zipformer/my_profile.py index 57f44a90a..ca20956fb 100755 --- a/egs/librispeech/ASR/zipformer/profile.py +++ b/egs/librispeech/ASR/zipformer/my_profile.py @@ -17,7 +17,7 @@ # limitations under the License. 
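The hunks above and below rename each recipe's profile.py to my_profile.py and repoint the gigaspeech and tedlium3 symlinks to the renamed file. The patch does not state why, but a likely reason is that a local file called profile.py shadows Python's standard-library profile module; a minimal sketch of that shadowing effect, under that assumption:

    # Run in a scratch directory: the local profile.py wins over the stdlib module
    # because the current directory (sys.path[0] for `python -c`) is searched first.
    cd "$(mktemp -d)"
    echo 'print("local profile.py imported")' > profile.py
    python3 -c "import profile; print(profile.__file__)"
    # Prints the path of ./profile.py; after renaming the file (for example to
    # my_profile.py) the same import resolves to the standard-library profiler.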
""" -Usage: ./zipformer/profile.py +Usage: ./zipformer/my_profile.py """ import argparse diff --git a/egs/tedlium3/ASR/zipformer/my_profile.py b/egs/tedlium3/ASR/zipformer/my_profile.py new file mode 120000 index 000000000..3a90b2628 --- /dev/null +++ b/egs/tedlium3/ASR/zipformer/my_profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/profile.py b/egs/tedlium3/ASR/zipformer/profile.py deleted file mode 120000 index c93adbd14..000000000 --- a/egs/tedlium3/ASR/zipformer/profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file From e5bb1ae86cc750a51626e9afcb973cc03fa72f86 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 24 Dec 2023 13:40:33 +0800 Subject: [PATCH 067/216] Use the CPU docker in CI to simplify the test code (#1427) --- .github/scripts/docker/Dockerfile | 21 ++-- .github/workflows/build-cpu-docker.yml | 8 +- .github/workflows/test.yml | 139 +++++++++--------------- .github/workflows/train-librispeech.yml | 8 +- 4 files changed, 72 insertions(+), 104 deletions(-) diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index 55c3aa1b9..bbf978d26 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -14,6 +14,7 @@ RUN apt-get update -y && \ ffmpeg \ git \ git-lfs \ + graphviz \ less \ vim \ && \ @@ -32,23 +33,23 @@ RUN pip install --no-cache-dir \ k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \ + dill \ + graphviz \ kaldi_native_io \ kaldialign \ kaldifst \ kaldilm \ - sentencepiece>=0.1.96 \ - tensorboard \ - typeguard \ - dill \ - onnx \ - onnxruntime \ - onnxmltools \ - six \ + matplotlib \ multi_quantization \ - typeguard \ numpy \ + onnx \ + onnxmltools \ + onnxruntime \ pytest \ - graphviz + sentencepiece>=0.1.96 \ + six \ + tensorboard \ + typeguard # RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ # cd /workspace/icefall && \ diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml index f931f7d09..b26cd2095 100644 --- a/.github/workflows/build-cpu-docker.yml +++ b/.github/workflows/build-cpu-docker.yml @@ -15,10 +15,10 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] k2-version: ["1.24.4.dev20231220"] kaldifeat-version: ["1.25.3.dev20231221"] - version: ["1.0"] + version: ["1.1"] steps: # refer to https://github.com/actions/checkout @@ -45,8 +45,12 @@ jobs: run: | cd .github/scripts/docker torch_version=${{ matrix.torch-version }} + + # see https://pytorch.org/audio/stable/installation.html#compatibility-matrix if [[ $torch_version == 1.13.0 ]]; then torchaudio_version=0.13.0 + elif [[ $torch_version == 1.13.1 ]]; then + torchaudio_version=0.13.1 elif [[ $torch_version == 2.0.0 ]]; then torchaudio_version=2.0.1 elif [[ $torch_version == 2.0.1 ]]; then diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 363556bb7..b3fd6f133 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,129 +1,94 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the 
Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - name: test on: push: branches: - master + pull_request: branches: - master + workflow_dispatch: + concurrency: group: test-${{ github.ref }} cancel-in-progress: true jobs: test: + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.8"] - torch: ["1.13.0"] - torchaudio: ["0.13.0"] - k2-version: ["1.24.3.dev20230719"] - - fail-fast: false + python-version: ["3.8", "3.9", "3.10"] + torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + version: ["1.1"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - - name: Install libnsdfile and libsox - if: startsWith(matrix.os, 'ubuntu') - run: | - sudo apt update - sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg - sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all - - - name: Install Python dependencies - run: | - python3 -m pip install --upgrade pip pytest - # numpy 1.20.x does not support python 3.6 - pip install numpy==1.19 - pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - - pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.github.io/k2/cpu.html - pip install git+https://github.com/lhotse-speech/lhotse - # icefall requirements - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - pip install kaldifst - pip install onnxruntime matplotlib - pip install -r requirements.txt - - - name: Install graphviz - if: startsWith(matrix.os, 'ubuntu') + - name: Free space shell: bash run: | - python3 -m pip install -qq graphviz - sudo apt-get -qq install graphviz + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" - name: Run tests - if: startsWith(matrix.os, 'ubuntu') - run: | - ls -lh - export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH - echo $PYTHONPATH - pytest -v -s ./test - # runt tests for conformer ctc - cd egs/librispeech/ASR/conformer_ctc - pytest -v -s + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall - cd ../pruned_transducer_stateless - pytest -v -s + pytest -v -s ./test - cd ../pruned_transducer_stateless2 - pytest -v -s + # runt tests for conformer ctc + cd egs/librispeech/ASR/conformer_ctc + pytest -v -s - cd ../pruned_transducer_stateless3 - pytest -v -s + cd 
../pruned_transducer_stateless + pytest -v -s - cd ../pruned_transducer_stateless4 - pytest -v -s + cd ../pruned_transducer_stateless2 + pytest -v -s - echo $PYTHONPATH - cd ../pruned_transducer_stateless7 - pytest -v -s + cd ../pruned_transducer_stateless3 + pytest -v -s - cd ../transducer_stateless - pytest -v -s + cd ../pruned_transducer_stateless4 + pytest -v -s - # cd ../transducer - # pytest -v -s + echo $PYTHONPATH + cd ../pruned_transducer_stateless7 + pytest -v -s - cd ../transducer_stateless2 - pytest -v -s + cd ../transducer_stateless + pytest -v -s - cd ../transducer_lstm - pytest -v -s + # cd ../transducer + # pytest -v -s - cd ../zipformer - pytest -v -s + cd ../transducer_stateless2 + pytest -v -s + + cd ../transducer_lstm + pytest -v -s + + cd ../zipformer + pytest -v -s - uses: actions/upload-artifact@v2 with: diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml index 7c9a28f03..53a2d5843 100644 --- a/.github/workflows/train-librispeech.yml +++ b/.github/workflows/train-librispeech.yml @@ -23,10 +23,8 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - k2-version: ["1.24.4.dev20231220"] - kaldifeat-version: ["1.25.3.dev20231221"] - version: ["1.0"] + torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + version: ["1.1"] steps: # refer to https://github.com/actions/checkout @@ -43,7 +41,7 @@ jobs: echo "pwd: $PWD" echo "github.workspace ${{ github.workspace }}" - - name: Run the build process with Docker + - name: Test zipformer/train.py with LibriSpeech uses: addnab/docker-run-action@v3 with: image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} From c855a58cfd8628dfe4ef2ffc0ac169d84a8ac0c5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 25 Dec 2023 19:41:09 +0800 Subject: [PATCH 068/216] Generate the dependency matrix by code for GitHub Actions (#1431) --- .github/scripts/docker/Dockerfile | 2 + .../scripts/docker/generate_build_matrix.py | 79 ++++++++ .../{docker => librispeech/ASR}/run.sh | 11 +- .github/scripts/yesno/ASR/run.sh | 86 +++++++++ .github/workflows/build-cpu-docker.yml | 42 ++-- .github/workflows/run-yesno-recipe.yml | 182 ++++-------------- .github/workflows/test.yml | 27 ++- .github/workflows/train-librispeech.yml | 33 +++- 8 files changed, 279 insertions(+), 183 deletions(-) create mode 100755 .github/scripts/docker/generate_build_matrix.py rename .github/scripts/{docker => librispeech/ASR}/run.sh (87%) create mode 100755 .github/scripts/yesno/ASR/run.sh diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index bbf978d26..f75d74854 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -31,10 +31,12 @@ LABEL github_repo="https://github.com/k2-fsa/icefall" RUN pip install --no-cache-dir \ torch==${TORCH_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/cpu/torch_stable.html \ k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \ + \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \ dill \ graphviz \ + kaldi-decoder \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py new file mode 100755 index 000000000..4e494d810 --- /dev/null +++ 
b/.github/scripts/docker/generate_build_matrix.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + + +import json + + +def version_gt(a, b): + a_major, a_minor = a.split(".")[:2] + b_major, b_minor = b.split(".")[:2] + if a_major > b_major: + return True + + if a_major == b_major and a_minor > b_minor: + return True + + return False + + +def version_ge(a, b): + a_major, a_minor = a.split(".")[:2] + b_major, b_minor = b.split(".")[:2] + if a_major > b_major: + return True + + if a_major == b_major and a_minor >= b_minor: + return True + + return False + + +def get_torchaudio_version(torch_version): + if torch_version == "1.13.0": + return "0.13.0" + elif torch_version == "1.13.1": + return "0.13.1" + elif torch_version == "2.0.0": + return "2.0.1" + elif torch_version == "2.0.1": + return "2.0.2" + else: + return torch_version + + +def get_matrix(): + k2_version = "1.24.4.dev20231220" + kaldifeat_version = "1.25.3.dev20231221" + version = "1.1" + python_version = ["3.8", "3.9", "3.10", "3.11"] + torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + + matrix = [] + for p in python_version: + for t in torch_version: + # torchaudio <= 1.13.x supports only python <= 3.10 + + if version_gt(p, "3.10") and not version_gt(t, "2.0"): + continue + + matrix.append( + { + "k2-version": k2_version, + "kaldifeat-version": kaldifeat_version, + "version": version, + "python-version": p, + "torch-version": t, + "torchaudio-version": get_torchaudio_version(t), + } + ) + return matrix + + +def main(): + matrix = get_matrix() + print(json.dumps({"include": matrix})) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/docker/run.sh b/.github/scripts/librispeech/ASR/run.sh similarity index 87% rename from .github/scripts/docker/run.sh rename to .github/scripts/librispeech/ASR/run.sh index aeb80b330..641d59458 100755 --- a/.github/scripts/docker/run.sh +++ b/.github/scripts/librispeech/ASR/run.sh @@ -1,11 +1,12 @@ #!/usr/bin/env bash + set -ex -cd /icefall -export PYTHONPATH=/icefall:$PYTHONPATH -python3 -c "import torch; print(torch.__file__)" -python3 -c "import torchaudio; print(torchaudio.__version__)" -python3 -c "import icefall; print(icefall.__file__)" +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} cd egs/librispeech/ASR diff --git a/.github/scripts/yesno/ASR/run.sh b/.github/scripts/yesno/ASR/run.sh new file mode 100755 index 000000000..05c8fbac9 --- /dev/null +++ b/.github/scripts/yesno/ASR/run.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/yesno/ASR + +log "data preparation" +./prepare.sh + +log "training" +python3 ./tdnn/train.py + +log "decoding" +python3 ./tdnn/decode.py + +log "export to pretrained.pt" + +python3 ./tdnn/export.py --epoch 14 --avg 2 + +python3 ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +log "Test exporting to torchscript" +python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + +python3 ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file 
./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +log "Test exporting to onnx" +python3 ./tdnn/export_onnx.py --epoch 14 --avg 2 + +log "Test float32 model" +python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +log "Test int8 model" +python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +log "Test decoding with H" +python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + +python3 ./tdnn/jit_pretrained_decode_with_H.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --H ./data/lang_phone/H.fst \ + --tokens ./data/lang_phone/tokens.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +log "Test decoding with HL" +python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + +python3 ./tdnn/jit_pretrained_decode_with_HL.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HL ./data/lang_phone/HL.fst \ + --words ./data/lang_phone/words.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +log "Show generated files" +ls -lh tdnn/exp +ls -lh data/lang_phone diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml index b26cd2095..c5d5aaeb6 100644 --- a/.github/workflows/build-cpu-docker.yml +++ b/.github/workflows/build-cpu-docker.yml @@ -7,18 +7,31 @@ concurrency: cancel-in-progress: true jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" build-cpu-docker: + needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: - os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - k2-version: ["1.24.4.dev20231220"] - kaldifeat-version: ["1.25.3.dev20231221"] - version: ["1.1"] + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: # refer to https://github.com/actions/checkout @@ -45,25 +58,14 @@ jobs: run: | cd .github/scripts/docker torch_version=${{ matrix.torch-version }} + torchaudio_version=${{ matrix.torchaudio-version }} - # see https://pytorch.org/audio/stable/installation.html#compatibility-matrix - if [[ $torch_version == 1.13.0 ]]; then - torchaudio_version=0.13.0 - elif [[ $torch_version == 1.13.1 ]]; then - torchaudio_version=0.13.1 - elif [[ $torch_version == 2.0.0 ]]; then - torchaudio_version=2.0.1 - elif [[ $torch_version == 2.0.1 ]]; 
then - torchaudio_version=2.0.2 - else - torchaudio_version=$torch_version - fi echo "torch_version: $torch_version" echo "torchaudio_version: $torchaudio_version" version=${{ matrix.version }} - tag=ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version + tag=ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version echo "tag: $tag" docker build \ diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 9ac848535..a99811815 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -20,166 +20,60 @@ on: push: branches: - master + - refactor-ci + pull_request: branches: - master + workflow_dispatch: + concurrency: group: run-yesno-recipe-${{ github.ref }} cancel-in-progress: true jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" run-yesno-recipe: - runs-on: ${{ matrix.os }} + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest strategy: - matrix: - # os: [ubuntu-latest, macos-10.15] - # TODO: enable macOS for CPU testing - os: [ubuntu-latest] - python-version: [3.8] fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - name: Run the yesno recipe + uses: addnab/docker-run-action@v3 with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall - - name: Install libnsdfile and libsox - if: startsWith(matrix.os, 'ubuntu') - run: | - sudo apt update - sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg - sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - pip install --no-deps --force-reinstall k2==1.24.4.dev20231021+cpu.torch1.13.1 -f https://k2-fsa.github.io/k2/cpu.html - pip install kaldifeat==1.25.1.dev20231022+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html - - - name: Run yesno recipe - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - ./prepare.sh - python3 ./tdnn/train.py - python3 ./tdnn/decode.py - - - name: Test exporting to pretrained.pt - shell: 
bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export.py --epoch 14 --avg 2 - - python3 ./tdnn/pretrained.py \ - --checkpoint ./tdnn/exp/pretrained.pt \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - - - name: Test exporting to torchscript - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 - - python3 ./tdnn/jit_pretrained.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - - - name: Test exporting to onnx - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export_onnx.py --epoch 14 --avg 2 - - echo "Test float32 model" - python3 ./tdnn/onnx_pretrained.py \ - --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - - - echo "Test int8 model" - python3 ./tdnn/onnx_pretrained.py \ - --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - - - name: Test decoding with H - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 - - python3 ./tdnn/jit_pretrained_decode_with_H.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --H ./data/lang_phone/H.fst \ - --tokens ./data/lang_phone/tokens.txt \ - ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ - ./download/waves_yesno/0_0_1_0_0_1_1_1.wav - - - name: Test decoding with HL - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 - - python3 ./tdnn/jit_pretrained_decode_with_HL.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --HL ./data/lang_phone/HL.fst \ - --words ./data/lang_phone/words.txt \ - ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ - ./download/waves_yesno/0_0_1_0_0_1_1_1.wav - - - name: Show generated files - shell: bash - working-directory: ${{github.workspace}} - run: | - cd egs/yesno/ASR - ls -lh tdnn/exp - ls -lh data/lang_phone + .github/scripts/yesno/ASR/run.sh diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b3fd6f133..659681b37 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,16 +16,31 @@ concurrency: cancel-in-progress: true jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build 
matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" test: + needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: - os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - version: ["1.1"] + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - uses: actions/checkout@v4 @@ -44,7 +59,7 @@ jobs: - name: Run tests uses: addnab/docker-run-action@v3 with: - image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} options: | --volume ${{ github.workspace }}/:/icefall shell: bash diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml index 53a2d5843..79002a881 100644 --- a/.github/workflows/train-librispeech.yml +++ b/.github/workflows/train-librispeech.yml @@ -15,16 +15,31 @@ concurrency: cancel-in-progress: true jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" train-librispeech: + needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: - os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - version: ["1.1"] + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: # refer to https://github.com/actions/checkout @@ -44,11 +59,13 @@ jobs: - name: Test zipformer/train.py with LibriSpeech uses: addnab/docker-run-action@v3 with: - image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} options: | --volume ${{ github.workspace }}/:/icefall shell: bash run: | - ls -lh /icefall + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall - /icefall/.github/scripts/docker/run.sh + .github/scripts/librispeech/ASR/run.sh From ddd71313179a1565ea4d9e2e37546c3ef6b98d90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ali=20Haznedaro=C4=9Flu?= <53865510+ahazned@users.noreply.github.com> Date: Mon, 25 Dec 2023 14:44:07 +0300 Subject: [PATCH 069/216] Update TTS export-onnx.py scripts for handling variable token counts (#1430) --- egs/ljspeech/TTS/vits/export-onnx.py | 6 +++++- 
egs/vctk/TTS/vits/export-onnx.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 36a9de27f..f82f9dbe9 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -149,6 +149,7 @@ class OnnxModel(nn.Module): def export_model_onnx( model: nn.Module, model_filename: str, + vocab_size: int, opset_version: int = 11, ) -> None: """Export the given generator model to ONNX format. @@ -165,10 +166,12 @@ def export_model_onnx( The VITS generator. model_filename: The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. opset_version: The opset version to use. """ - tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) noise_scale = torch.tensor([1], dtype=torch.float32) noise_scale_dur = torch.tensor([1], dtype=torch.float32) @@ -244,6 +247,7 @@ def main(): export_model_onnx( model, model_filename, + params.vocab_size, opset_version=opset_version, ) logging.info(f"Exported generator to {model_filename}") diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index 667ac284b..80d155626 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -159,6 +159,7 @@ class OnnxModel(nn.Module): def export_model_onnx( model: nn.Module, model_filename: str, + vocab_size: int, opset_version: int = 11, ) -> None: """Export the given generator model to ONNX format. @@ -175,10 +176,12 @@ def export_model_onnx( The VITS generator. model_filename: The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. opset_version: The opset version to use. """ - tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) noise_scale = torch.tensor([1], dtype=torch.float32) noise_scale_dur = torch.tensor([1], dtype=torch.float32) @@ -261,6 +264,7 @@ def main(): export_model_onnx( model, model_filename, + params.vocab_size, opset_version=opset_version, ) logging.info(f"Exported generator to {model_filename}") From 835a92eba51a939c6b4a069a53cc1e3ddeabd9a5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 25 Dec 2023 20:23:56 +0800 Subject: [PATCH 070/216] Add doc about how to use the CPU-only docker images (#1432) --- docs/source/docker/intro.rst | 46 ++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index 9ead0df00..cbd300d9b 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -20,7 +20,11 @@ We describe the following items in this section: View available tags =================== -You can use the following command to view available tags: +CUDA-enabled docker images +-------------------------- + +You can use the following command to view available tags for CUDA-enabled +docker images: .. code-block:: bash @@ -43,8 +47,25 @@ which will give you something like below: Please select an appropriate combination of `torch`_ and CUDA. -Download a docker image -======================= +CPU-only docker images +---------------------- + +To view CPU-only docker images, please visit ``_ +for available tags. 
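The CPU-only tags listed there follow the same naming scheme the build-cpu-docker workflow earlier in this series uses when it pushes images, so any listed tag can be reconstructed from its parts. An illustrative sketch only, with placeholder version values:

    # Placeholder values; substitute the Python/torch/image versions you need.
    PYTHON_VERSION=3.8
    TORCH_VERSION=2.1.2
    IMAGE_VERSION=1.1
    tag=ghcr.io/k2-fsa/icefall:cpu-py${PYTHON_VERSION}-torch${TORCH_VERSION}-v${IMAGE_VERSION}
    echo "$tag"   # ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.1
    sudo docker pull "$tag"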
+ +You can select different combinations of ``Python`` and ``torch``. For instance, +to select ``Python 3.8`` and ``torch 2.1.2``, you can use the following tag + +.. code-block:: bash + + cpu-py3.8-torch2.1.2-v1.1 + +where ``v1.1`` is the current version of the docker image. You may see +``ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.2`` or some other versions. +We recommend that you always use the latest version. + +Download a docker image (CUDA) +============================== Suppose that you select the tag ``torch1.13.0-cuda11.6``, you can use the following command to download it: @@ -53,6 +74,16 @@ the following command to download it: sudo docker image pull k2fsa/icefall:torch1.13.0-cuda11.6 +Download a docker image (CPU) +============================== + +Suppose that you select the tag ``cpu-py3.8-torch2.1.2-v1.1``, you can use +the following command to download it: + +.. code-block:: bash + + sudo docker pull ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.1 + Run a docker image with GPU =========================== @@ -65,7 +96,7 @@ Run a docker image with CPU .. code-block:: bash - sudo docker run --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash + sudo docker run --rm -it ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.1 /bin/bash Run yesno within a docker container =================================== @@ -74,8 +105,13 @@ After starting the container, the following interface is presented: .. code-block:: bash + # GPU-enabled docker root@60c947eac59c:/workspace/icefall# + # CPU-only docker + root@60c947eac59c:# mkdir /workspace; git clone https://github.com/k2-fsa/icefall + root@60c947eac59c:# export PYTHONPATH=/workspace/icefall:$PYTHONPATH + It shows the current user is ``root`` and the current working directory is ``/workspace/icefall``. @@ -107,7 +143,7 @@ to switch to the ``yesno`` recipe and run .. hint:: - If you are running without GPU, it may report the following error: + If you are running without GPU with a GPU-enabled docker, it may report the following error: .. 
code-block:: bash From db52fe2349df0e07e931accb0cf1e63fec389fb7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Dec 2023 20:29:43 +0800 Subject: [PATCH 071/216] Refactor CI test for aishell (#1435) --- .github/scripts/aishell/ASR/run.sh | 274 ++++++++++++++++++ .github/scripts/docker/Dockerfile | 1 + .../scripts/docker/generate_build_matrix.py | 2 +- ...pruned-transducer-stateless3-2022-06-20.sh | 87 ------ .../run-aishell-zipformer-2023-10-24.sh | 103 ------- ...transducer-stateless-modified-2-aishell.sh | 48 --- ...d-transducer-stateless-modified-aishell.sh | 48 --- .github/workflows/aishell.yml | 81 ++++++ .github/workflows/run-aishell-2022-06-20.yml | 123 -------- .../run-aishell-zipformer-2023-10-24.yml | 95 ------ ...ransducer-stateless-modified-2-aishell.yml | 80 ----- ...-transducer-stateless-modified-aishell.yml | 80 ----- .../{run-yesno-recipe.yml => yesno.yml} | 23 +- 13 files changed, 360 insertions(+), 685 deletions(-) create mode 100755 .github/scripts/aishell/ASR/run.sh delete mode 100755 .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh delete mode 100755 .github/scripts/run-aishell-zipformer-2023-10-24.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh create mode 100644 .github/workflows/aishell.yml delete mode 100644 .github/workflows/run-aishell-2022-06-20.yml delete mode 100644 .github/workflows/run-aishell-zipformer-2023-10-24.yml delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml rename .github/workflows/{run-yesno-recipe.yml => yesno.yml} (69%) diff --git a/.github/scripts/aishell/ASR/run.sh b/.github/scripts/aishell/ASR/run.sh new file mode 100755 index 000000000..4d912fa76 --- /dev/null +++ b/.github/scripts/aishell/ASR/run.sh @@ -0,0 +1,274 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell/ASR + +function download_test_dev_manifests() { + git lfs install + + fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests + log "Downloading pre-commputed fbank from $fbank_url" + + git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests + ln -s $PWD/aishell-test-dev-manifests/data . 
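The helper above clones the aishell-test-dev-manifests repository and symlinks its data directory into egs/aishell/ASR, presumably so the gated decoding steps further down can run on pre-computed fbank features without executing the full prepare.sh. A quick sanity check of the link, shown only as an illustration and not part of the script:

    # The recipe now sees ./data through the symlink created above.
    ls -lh data
    readlink -f data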
+} + +function test_transducer_stateless3_2022_06_20() { + repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 + log "Downloading pre-trained model from $repo_url" + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt pretrained.pt + popd + + log "test greedy_search with pretrained.py" + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless3/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + + log "test beam search with pretrained.py" + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless3/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + + echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" + echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" + if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless3/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_char data/ + + ls -lh data + ls -lh pruned_transducer_stateless3/exp + + log "Decoding test and dev" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless3/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless3/exp + done + + rm pruned_transducer_stateless3/exp/*.pt + fi + + rm -rf $repo +} + +function test_zipformer_large_2023_10_24() { + log "CI testing large model" + repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/ + log "Downloading pre-trained model from $repo_url" + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +function test_zipformer_2023_10_24() { + repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/ + log "Downloading pre-trained model from $repo_url" + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + + for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method 
$method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +function test_zipformer_small_2023_10_24() { + log "CI testing small model" + repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/ + log "Downloading pre-trained model from $repo_url" + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + + for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +function test_transducer_stateless_modified_2022_03_01() { + repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_modified/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + + for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_modified/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +function test_transducer_stateless_modified_2_2022_03_01() { + repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_modified-2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + + for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_modified-2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +download_test_dev_manifests +test_transducer_stateless3_2022_06_20 
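The refactored script wraps each pre-trained model in its own bash function and then invokes the functions one after another here, replacing the separate per-model CI scripts deleted below. A hypothetical skeleton for adding one more model test in the same style; the Hugging Face URL and the options are placeholders, not an existing model:

    function test_some_new_aishell_model() {
      repo_url=https://huggingface.co/OWNER/SOME-AISHELL-MODEL   # placeholder URL
      git clone $repo_url
      repo=$(basename $repo_url)

      # Exercise the pretrained.py entry point the same way the functions above do.
      ./zipformer/pretrained.py \
        --method greedy_search \
        --checkpoint $repo/exp/pretrained.pt \
        --tokens $repo/data/lang_char/tokens.txt \
        $repo/test_wavs/BAC009S0764W0121.wav

      rm -rf $repo
    }
    # Then append test_some_new_aishell_model to the call list at the bottom of run.sh.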
+test_zipformer_large_2023_10_24 +test_zipformer_2023_10_24 +test_zipformer_small_2023_10_24 +test_transducer_stateless_modified_2022_03_01 +test_transducer_stateless_modified_2_2022_03_01 + +ls -lh diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index f75d74854..f6a088af1 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -16,6 +16,7 @@ RUN apt-get update -y && \ git-lfs \ graphviz \ less \ + tree \ vim \ && \ apt-get clean && \ diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 4e494d810..bdde97647 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20231220" kaldifeat_version = "1.25.3.dev20231221" - version = "1.1" + version = "1.2" python_version = ["3.8", "3.9", "3.10", "3.11"] torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh deleted file mode 100755 index c3640cfde..000000000 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/aishell/ASR - -git lfs install - -fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests -log "Downloading pre-commputed fbank from $fbank_url" - -git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests -ln -s $PWD/aishell-test-dev-manifests/data . 
- -repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless3/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless3/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless3/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_char data/ - - ls -lh data - ls -lh pruned_transducer_stateless3/exp - - log "Decoding test and dev" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless3/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless3/exp - done - - rm pruned_transducer_stateless3/exp/*.pt -fi diff --git a/.github/scripts/run-aishell-zipformer-2023-10-24.sh b/.github/scripts/run-aishell-zipformer-2023-10-24.sh deleted file mode 100755 index 865e29799..000000000 --- a/.github/scripts/run-aishell-zipformer-2023-10-24.sh +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/aishell/ASR - -git lfs install - -fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests -log "Downloading pre-commputed fbank from $fbank_url" - -git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests -ln -s $PWD/aishell-test-dev-manifests/data . 
- -log "=======================" -log "CI testing large model" -repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/ -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for method in modified_beam_search greedy_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --method $method \ - --context-size 1 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_char/tokens.txt \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -log "=======================" -log "CI testing medium model" -repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/ -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - - -for method in modified_beam_search greedy_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --method $method \ - --context-size 1 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_char/tokens.txt \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - - -log "=======================" -log "CI testing small model" -repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/ -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - - -for method in modified_beam_search greedy_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --method $method \ - --context-size 1 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_char/tokens.txt \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,768,768,768,768 \ - --encoder-dim 192,256,256,256,256,256 \ - --encoder-unmasked-dim 192,192,192,192,192,192 \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh deleted file mode 100755 index 0644d9be0..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/aishell/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless_modified-2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir 
$repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -for method in modified_beam_search beam_search; do - log "$method" - - ./transducer_stateless_modified-2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh deleted file mode 100755 index 79fb64311..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/aishell/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless_modified/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -for method in modified_beam_search beam_search; do - log "$method" - - ./transducer_stateless_modified/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml new file mode 100644 index 000000000..136e117bd --- /dev/null +++ b/.github/workflows/aishell.yml @@ -0,0 +1,81 @@ +name: aishell + +on: + push: + branches: + - master + + pull_request: + branches: + - master + + workflow_dispatch: + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: aishell-${{ github.ref }} + cancel-in-progress: true + +jobs: + generate_build_matrix: + if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'aishell' || github.event_name == 'schedule') + + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + aishell: + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ 
matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run aishell tests + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall + + .github/scripts/aishell/ASR/run.sh diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml deleted file mode 100644 index 53fcb2c03..000000000 --- a/.github/workflows/run-aishell-2022-06-20.yml +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: run-aishell-2022-06-20 -# pruned RNN-T + reworked model with random combiner -# https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_aishell_2022_06_20-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_aishell_2022_06_20: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - 
sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh - - - name: Display decoding results for aishell pruned_transducer_stateless3 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/aishell/ASR/ - tree ./pruned_transducer_stateless3/exp - - cd pruned_transducer_stateless3 - echo "results for pruned_transducer_stateless3" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 - - - name: Upload decoding results for aishell pruned_transducer_stateless3 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-06-20 - path: egs/aishell/ASR/pruned_transducer_stateless3/exp/ diff --git a/.github/workflows/run-aishell-zipformer-2023-10-24.yml b/.github/workflows/run-aishell-zipformer-2023-10-24.yml deleted file mode 100644 index f2fb44a5f..000000000 --- a/.github/workflows/run-aishell-zipformer-2023-10-24.yml +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2023 Zengrui Jin (Xiaomi Corp.) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-aishell-zipformer-2023-10-24 - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_aishell_zipformer_2023_10_24-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_aishell_zipformer_2023_10_24: - if: github.event.label.name == 'ready' || github.event.label.name == 'zipformer' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-aishell-zipformer-2023-10-24.sh - - \ No newline at end of file diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml deleted file mode 100644 index ce6d6f92d..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-trandsucer-stateless-modified-2-aishell - -on: - push: - branches: - - master - pull_request: - types: [labeled] - -concurrency: - group: run_pre_trained_transducer_stateless_modified_2_aishell-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless_modified_2_aishell: - if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml deleted file mode 100644 index f0cebd94a..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-trandsucer-stateless-modified-aishell - -on: - push: - branches: - - master - pull_request: - types: [labeled] - -concurrency: - group: run_pre_trained_transducer_stateless_modified_aishell-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless_modified_aishell: - if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/yesno.yml similarity index 69% rename from .github/workflows/run-yesno-recipe.yml rename to .github/workflows/yesno.yml index a99811815..182300dfa 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/yesno.yml @@ -1,26 +1,9 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-yesno-recipe +name: yesno on: push: branches: - master - - refactor-ci pull_request: branches: @@ -29,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: run-yesno-recipe-${{ github.ref }} + group: yesno-${{ github.ref }} cancel-in-progress: true jobs: @@ -50,7 +33,7 @@ jobs: python ./.github/scripts/docker/generate_build_matrix.py MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) echo "::set-output name=matrix::${MATRIX}" - run-yesno-recipe: + yesno: needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} runs-on: ubuntu-latest From 140e6381ad4699ce919705f59240fad58c0b1bb6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 27 Dec 2023 13:21:14 +0800 Subject: [PATCH 072/216] Refactor CI tests for librispeech (#1436) --- .github/scripts/aishell/ASR/run.sh | 73 +- .github/scripts/librispeech/ASR/run.sh | 1624 ++++++++++++++++- ...n-librispeech-conformer-ctc3-2022-11-28.sh | 122 -- ...-pruned-transducer-stateless-2022-03-12.sh | 77 - ...pruned-transducer-stateless2-2022-04-29.sh | 86 - ...pruned-transducer-stateless3-2022-04-29.sh | 85 - ...pruned-transducer-stateless3-2022-05-13.sh | 123 -- ...pruned-transducer-stateless5-2022-05-13.sh | 100 - ...pruned-transducer-stateless7-2022-11-11.sh | 106 -- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 150 -- ...transducer-stateless7-ctc-bs-2023-01-29.sh | 147 -- ...nsducer-stateless7-streaming-2022-12-29.sh | 148 -- ...pruned-transducer-stateless8-2022-11-14.sh | 115 -- ...pruned-transducer-stateless2-2022-06-26.sh | 101 - ...rispeech-streaming-zipformer-2023-05-18.sh | 116 -- ...speech-transducer-stateless2-2022-04-19.sh | 77 - .../run-librispeech-zipformer-2023-05-18.sh | 94 - ...un-librispeech-zipformer-ctc-2023-06-14.sh | 117 -- ...un-librispeech-zipformer-mmi-2022-12-08.sh | 102 -- .github/scripts/run-pre-trained-ctc.sh | 240 --- ...d-transducer-stateless-librispeech-100h.sh | 77 - ...d-transducer-stateless-librispeech-960h.sh | 77 - .../run-pre-trained-transducer-stateless.sh | 77 - .github/scripts/run-pre-trained-transducer.sh | 33 - .github/workflows/aishell.yml | 11 +- ...{train-librispeech.yml => librispeech.yml} | 6 +- .../workflows/run-librispeech-2022-03-12.yml | 159 -- .../workflows/run-librispeech-2022-04-29.yml | 185 -- .../workflows/run-librispeech-2022-05-13.yml | 159 -- .../run-librispeech-2022-11-11-stateless7.yml | 159 -- .../run-librispeech-2022-11-14-stateless8.yml | 159 -- ...-librispeech-2022-12-01-stateless7-ctc.yml | 163 -- ...n-librispeech-2022-12-08-zipformer-mmi.yml | 167 -- ...speech-2022-12-29-stateless7-streaming.yml | 172 -- ...brispeech-2023-01-29-stateless7-ctc-bs.yml | 163 -- ...-librispeech-conformer-ctc3-2022-11-28.yml | 155 -- ...runed-transducer-stateless3-2022-05-13.yml | 157 -- ...aming-transducer-stateless2-2022-06-26.yml | 159 -- ...ispeech-streaming-zipformer-2023-05-18.yml | 174 -- ...peech-transducer-stateless2-2022-04-19.yml | 159 -- .../run-librispeech-zipformer-2023-05-18.yml | 159 -- ...n-librispeech-zipformer-ctc-2023-06-14.yml | 155 -- .github/workflows/run-pretrained-ctc.yml | 87 - ...-transducer-stateless-librispeech-100h.yml | 158 -- ...r-stateless-librispeech-multi-datasets.yml | 158 -- .../run-pretrained-transducer-stateless.yml | 158 -- .../workflows/run-pretrained-transducer.yml | 80 - 47 files changed, 1658 insertions(+), 5671 deletions(-) delete mode 100755 .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh delete mode 100755 
.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh delete mode 100755 .github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh delete mode 100755 .github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh delete mode 100755 .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh delete mode 100755 .github/scripts/run-librispeech-zipformer-2023-05-18.sh delete mode 100755 .github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh delete mode 100755 .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh delete mode 100755 .github/scripts/run-pre-trained-ctc.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless.sh delete mode 100755 .github/scripts/run-pre-trained-transducer.sh rename .github/workflows/{train-librispeech.yml => librispeech.yml} (95%) delete mode 100644 .github/workflows/run-librispeech-2022-03-12.yml delete mode 100644 .github/workflows/run-librispeech-2022-04-29.yml delete mode 100644 .github/workflows/run-librispeech-2022-05-13.yml delete mode 100644 .github/workflows/run-librispeech-2022-11-11-stateless7.yml delete mode 100644 .github/workflows/run-librispeech-2022-11-14-stateless8.yml delete mode 100644 .github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml delete mode 100644 .github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml delete mode 100644 .github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml delete mode 100644 .github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml delete mode 100644 .github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml delete mode 100644 .github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml delete mode 100644 .github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml delete mode 100644 .github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml delete mode 100644 .github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml delete mode 100644 .github/workflows/run-librispeech-zipformer-2023-05-18.yml delete mode 100644 .github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml delete mode 100644 .github/workflows/run-pretrained-ctc.yml delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml delete mode 100644 
.github/workflows/run-pretrained-transducer-stateless.yml delete mode 100644 .github/workflows/run-pretrained-transducer.yml diff --git a/.github/scripts/aishell/ASR/run.sh b/.github/scripts/aishell/ASR/run.sh index 4d912fa76..f150b6337 100755 --- a/.github/scripts/aishell/ASR/run.sh +++ b/.github/scripts/aishell/ASR/run.sh @@ -263,6 +263,76 @@ function test_transducer_stateless_modified_2_2022_03_01() { rm -rf $repo } +function test_conformer_ctc() { + repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + pushd $repo + + git lfs pull --include "exp/pretrained.pt" + git lfs pull --include "data/lang_char/H.fst" + git lfs pull --include "data/lang_char/HL.fst" + git lfs pull --include "data/lang_char/HLG.fst" + + popd + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + log "CTC decoding" + + log "Exporting model with torchscript" + + pushd $repo/exp + ln -s pretrained.pt epoch-99.pt + popd + + ./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_char/tokens.txt \ + --jit 1 + + ls -lh $repo/exp + + ls -lh $repo/data/lang_char + + log "Decoding with H on CPU with OpenFst" + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --H $repo/data/lang_char/H.fst \ + --tokens $repo/data/lang_char/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "Decoding with HL on CPU with OpenFst" + + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HL $repo/data/lang_char/HL.fst \ + --words $repo/data/lang_char/words.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "Decoding with HLG on CPU with OpenFst" + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HLG $repo/data/lang_char/HLG.fst \ + --words $repo/data/lang_char/words.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + rm -rf $repo +} + download_test_dev_manifests test_transducer_stateless3_2022_06_20 test_zipformer_large_2023_10_24 @@ -270,5 +340,4 @@ test_zipformer_2023_10_24 test_zipformer_small_2023_10_24 test_transducer_stateless_modified_2022_03_01 test_transducer_stateless_modified_2_2022_03_01 - -ls -lh +# test_conformer_ctc # fails for torch 1.13.x and torch 2.0.x diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh index 641d59458..7e9bd8a47 100755 --- a/.github/scripts/librispeech/ASR/run.sh +++ b/.github/scripts/librispeech/ASR/run.sh @@ -10,52 +10,1594 @@ log() { cd egs/librispeech/ASR -# We don't download the LM file since it is so large that it will -# cause OOM error for CI later. -mkdir -p download/lm -pushd download/lm -wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt -wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt -wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz -ls -lh -gunzip librispeech-lm-norm.txt.gz +function prepare_data() { + # We don't download the LM file since it is so large that it will + # cause OOM error for CI later. 
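+  # Only librispeech-vocab.txt, librispeech-lexicon.txt and the normalized LM
+  # training text (librispeech-lm-norm.txt.gz) are fetched below.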
+ mkdir -p download/lm + pushd download/lm + wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt + wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt + wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz + ls -lh + gunzip librispeech-lm-norm.txt.gz -ls -lh -popd + ls -lh + popd -pushd download/ -wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2 -tar xf LibriSpeech.tar.bz2 -rm LibriSpeech.tar.bz2 + pushd download/ + wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2 + tar xf LibriSpeech.tar.bz2 + rm LibriSpeech.tar.bz2 -cd LibriSpeech -ln -s train-clean-100 train-clean-360 -ln -s train-other-500 train-other-500 -popd + cd LibriSpeech + ln -s train-clean-100 train-clean-360 + ln -s train-other-500 train-other-500 + popd -mkdir -p data/manifests + mkdir -p data/manifests -lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests -ls -lh data/manifests + lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests + ls -lh data/manifests -./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False -ls -lh data/fbank + ./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False + ls -lh data/fbank -./prepare.sh --stage 5 --stop-stage 6 + ./prepare.sh --stage 5 --stop-stage 6 +} -./zipformer/train.py \ - --world-size 1 \ - --num-epochs 1 \ - --start-epoch 1 \ - --use-fp16 0 \ - --exp-dir zipformer/exp-small \ - --causal 0 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 64,96,96,96,96,96 \ - --encoder-dim 32,64,64,64,64,64 \ - --encoder-unmasked-dim 32,32,32,32,32,32 \ - --base-lr 0.04 \ - --full-libri 0 \ - --enable-musan 0 \ - --max-duration 30 \ - --print-diagnostics 1 +function run_diagnostics() { + ./zipformer/train.py \ + --world-size 1 \ + --num-epochs 1 \ + --start-epoch 1 \ + --use-fp16 0 \ + --exp-dir zipformer/exp-small \ + --causal 0 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 64,96,96,96,96,96 \ + --encoder-dim 32,64,64,64,64,64 \ + --encoder-unmasked-dim 32,32,32,32,32,32 \ + --base-lr 0.04 \ + --full-libri 0 \ + --enable-musan 0 \ + --max-duration 30 \ + --print-diagnostics 1 +} + +function test_pruned_transducer_stateless_2022_03_12() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in fast_beam_search modified_beam_search beam_search; do + log "$method" + + ./pruned_transducer_stateless/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + 
$repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless2_2022_04_29() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + pushd $repo + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt" + popd + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-38-avg-10.pt pretrained.pt + popd + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless3_2022_04_29() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + pushd $repo + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt" + popd + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-25-avg-6.pt pretrained.pt + popd + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless3/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless3/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless5_2022_05_13() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-39-avg-7.pt pretrained.pt + popd + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless5/pretrained.py \ + --method greedy_search \ + 
--max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless5/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + done + rm -rf $repo +} + +function test_pruned_transducer_stateless7_2022_11_11() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./pruned_transducer_stateless7/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless7/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless8_2022_11_14() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Decode with 
models exported by torch.jit.script()" + + ./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Export to torchscript model" + ./pruned_transducer_stateless8/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model false \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless8/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless8/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless7_ctc_2022_12_01() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lm/G_4_gram.pt" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless7_ctc/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + 
$repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_zipformer_mmi_2022_12_08() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/3gram.pt" + git lfs pull --include "data/lang_bpe_500/4gram.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./zipformer_mmi/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./zipformer_mmi/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + --lang-dir $repo/data/lang_bpe_500 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do + log "$method" + + ./zipformer_mmi/pretrained.py \ + --method $method \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_bpe_500 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless7_streaming_2022_12_29() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + 
repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + git lfs pull --include "exp/encoder_jit_trace.pt" + git lfs pull --include "exp/decoder_jit_trace.pt" + git lfs pull --include "exp/joiner_jit_trace.pt" + cd exp + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless7_streaming/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Export to torchscript model by torch.jit.trace()" + ./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 + + log "Decode with models exported by torch.jit.trace()" + + ./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + rm -rf $repo +} + +function test_pruned_transducer_stateless7_ctc_bs_2023_01_29() { + repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to 
torchscript model" + ./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_conformer_ctc3_2022_11_27() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lm/G_4_gram.pt" + git lfs pull --include "exp/jit_trace.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Decode with models exported by torch.jit.trace()" + + for m in ctc-decoding 1best; do + ./conformer_ctc3/jit_pretrained.py \ + --model-filename $repo/exp/jit_trace.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + 
$repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + log "Export to torchscript model" + + ./conformer_ctc3/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --jit-trace 1 \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.trace()" + + for m in ctc-decoding 1best; do + ./conformer_ctc3/jit_pretrained.py \ + --model-filename $repo/exp/jit_trace.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for m in ctc-decoding 1best; do + ./conformer_ctc3/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_lstm_transducer_stateless2_2022_09_03() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + abs_repo=$(realpath $repo) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-iter-468000-avg-16.pt pretrained.pt + ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt + popd + + log "Test exporting with torch.jit.trace()" + + ./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --jit-trace 1 + + log "Decode with models exported by torch.jit.trace()" + + ./lstm_transducer_stateless2/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./lstm_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./lstm_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless3_2022_05_13() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 + + log "Downloading pre-trained 
model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt + ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt + popd + + + log "Export to torchscript model" + ./pruned_transducer_stateless3/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ./pruned_transducer_stateless3/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit-trace 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.trace()" + + ./pruned_transducer_stateless3/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless3/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_script.pt \ + --decoder-model-filename $repo/exp/decoder_jit_script.pt \ + --joiner-model-filename $repo/exp/joiner_jit_script.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless3/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless3/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + rm -rf $repo +} + +function test_streaming_pruned_transducer_stateless2_20220625() { + repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-24-avg-10.pt pretrained.pt + popd + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint 
$repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_streaming_zipformer_2023_05_17() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lang_bpe_500/tokens.txt" + git lfs pull --include "exp/jit_script_chunk_16_left_128.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./zipformer/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./zipformer/jit_pretrained_streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \ + $repo/test_wavs/1089-134686-0001.wav + + for method in greedy_search modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_zipformer_2023_05_18() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lang_bpe_500/tokens.txt" + git lfs pull --include "exp/jit_script.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./zipformer/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./zipformer/jit_pretrained.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --nn-model-filename $repo/exp/jit_script.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for method in greedy_search modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function 
test_transducer_stateless2_torchaudio_2022_04_19() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in fast_beam_search modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_zipformer_transducer_ctc_2023_06_13() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lang_bpe_500/tokens.txt" + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lm/G_4_gram.pt" + git lfs pull --include "exp/jit_script.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./zipformer/export.py \ + --exp-dir $repo/exp \ + --use-transducer 1 \ + --use-ctc 1 \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + for method in ctc-decoding 1best; do + ./zipformer/jit_pretrained_ctc.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --model-filename $repo/exp/jit_script.pt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --G $repo/data/lm/G_4_gram.pt \ + --method $method \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in ctc-decoding 1best; do + log "$method" + + ./zipformer/pretrained_ctc.py \ + --use-transducer 1 \ + --use-ctc 1 \ + --method $method \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --G $repo/data/lm/G_4_gram.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_100h_transducer_stateless_multi_datasets_bpe_500_2022_02_21() { + 
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_transducer_stateless_multi_datasets_bpe_500_2022_03_01() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_transducer_stateless_bpe_500_2022_02_07() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in fast_beam_search modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + 
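# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original patch): every test function in this
# script repeats the same preamble -- clone the Hugging Face repo, resolve its
# local directory, and list the bundled test wavs.  A small helper along the
# lines below could factor that out if the script is ever refactored.  The name
# clone_hf_repo and its optional skip-smudge argument are assumptions of this
# sketch only; it relies on the log() helper already defined at the top of the
# script.
function clone_hf_repo() {
  local repo_url=$1
  # Pass 1 as the second argument to defer LFS blobs (GIT_LFS_SKIP_SMUDGE=1),
  # so that individual files can be fetched later via `git lfs pull --include`.
  local skip_smudge=${2:-0}

  log "Downloading pre-trained model from $repo_url"
  git lfs install
  if [ "$skip_smudge" -eq 1 ]; then
    GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  else
    git clone $repo_url
  fi

  repo=$(basename $repo_url)

  log "Display test files"
  tree $repo/
  ls -lh $repo/test_wavs/*.wav
}
# Hypothetical usage: clone_hf_repo $repo_url 1
# ---------------------------------------------------------------------------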
+function test_zipformer_ctc_en_2023_10_02() { + repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + log "CTC greedy search" + + ./zipformer/onnx_pretrained_ctc.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "CTC H decoding" + + ./zipformer/onnx_pretrained_ctc_H.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + --H $repo/H.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "CTC HL decoding" + + ./zipformer/onnx_pretrained_ctc_HL.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HL $repo/HL.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "CTC HLG decoding" + + ./zipformer/onnx_pretrained_ctc_HLG.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HLG $repo/HLG.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + rm -rf $repo +} + +function test_conformer_ctc_jit_bpe_500_2021_11_09() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + pushd $repo + + git lfs pull --include "exp/pretrained.pt" + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/L_disambig.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lang_bpe_500/lexicon.txt" + git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt" + git lfs pull --include "data/lang_bpe_500/tokens.txt" + git lfs pull --include "data/lang_bpe_500/words.txt" + git lfs pull --include "data/lm/G_3_gram.fst.txt" + + popd + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + log "CTC decoding" + + ./conformer_ctc/pretrained.py \ + --method ctc-decoding \ + --num-classes 500 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "HLG decoding" + + ./conformer_ctc/pretrained.py \ + --method 1best \ + --num-classes 500 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "CTC decoding on CPU with kaldi decoders using OpenFst" + + log "Exporting model with torchscript" + + pushd $repo/exp + ln -s pretrained.pt epoch-99.pt + popd + + ./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --jit 1 + + ls -lh $repo/exp + + + log "Generating H.fst, HL.fst" + + ./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt + + ls -lh $repo/data/lang_bpe_500 + + log "Decoding with H on CPU with OpenFst" + + 
./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --H $repo/data/lang_bpe_500/H.fst \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Decoding with HL on CPU with OpenFst" + + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HL $repo/data/lang_bpe_500/HL.fst \ + --words $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Decoding with HLG on CPU with OpenFst" + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HLG $repo/data/lang_bpe_500/HLG.fst \ + --words $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + rm -rf $repo +} + +function test_transducer_bpe_500_2021_12_23() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + log "Beam search decoding" + + ./transducer/pretrained.py \ + --method beam_search \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + rm -rf $repo +} + +prepare_data +run_diagnostics +test_pruned_transducer_stateless_2022_03_12 +test_pruned_transducer_stateless2_2022_04_29 +test_pruned_transducer_stateless3_2022_04_29 +test_pruned_transducer_stateless5_2022_05_13 +test_pruned_transducer_stateless7_2022_11_11 +test_pruned_transducer_stateless8_2022_11_14 +test_pruned_transducer_stateless7_ctc_2022_12_01 +test_zipformer_mmi_2022_12_08 +test_pruned_transducer_stateless7_streaming_2022_12_29 +test_pruned_transducer_stateless7_ctc_bs_2023_01_29 +test_conformer_ctc3_2022_11_27 +test_lstm_transducer_stateless2_2022_09_03 +test_pruned_transducer_stateless3_2022_05_13 +test_streaming_pruned_transducer_stateless2_20220625 +test_streaming_zipformer_2023_05_17 +test_zipformer_2023_05_18 +test_transducer_stateless2_torchaudio_2022_04_19 +test_zipformer_transducer_ctc_2023_06_13 +test_100h_transducer_stateless_multi_datasets_bpe_500_2022_02_21 +test_transducer_stateless_multi_datasets_bpe_500_2022_03_01 +test_transducer_stateless_bpe_500_2022_02_07 +test_zipformer_ctc_en_2023_10_02 +# test_conformer_ctc_jit_bpe_500_2021_11_09 # fails for torch != 1.13.x and torch != 2.0.x +test_transducer_bpe_500_2021_12_23 diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh deleted file mode 100755 index f6fe8c9b2..000000000 --- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone 
$repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lm/G_4_gram.pt" -git lfs pull --include "exp/jit_trace.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Decode with models exported by torch.jit.trace()" - -for m in ctc-decoding 1best; do - ./conformer_ctc3/jit_pretrained.py \ - --model-filename $repo/exp/jit_trace.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -log "Export to torchscript model" - -./conformer_ctc3/export.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --jit-trace 1 \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.trace()" - -for m in ctc-decoding 1best; do - ./conformer_ctc3/jit_pretrained.py \ - --model-filename $repo/exp/jit_trace.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for m in ctc-decoding 1best; do - ./conformer_ctc3/pretrained.py \ - --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p conformer_ctc3/exp - ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh conformer_ctc3/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in ctc-decoding 1best; do - log "Decoding with $method" - ./conformer_ctc3/decode.py \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --exp-dir conformer_ctc3/exp/ \ - --max-duration $max_duration \ - --decoding-method $method \ - --lm-dir data/lm - done - - rm conformer_ctc3/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh deleted file mode 100755 index 412e3ad56..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local 
fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in fast_beam_search modified_beam_search beam_search; do - log "$method" - - ./pruned_transducer_stateless/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless/exp - done - - rm pruned_transducer_stateless/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh deleted file mode 100755 index 243b669ed..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt" -popd - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-38-avg-10.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - 
$repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless2/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless2/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless2/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless2/exp - done - - rm pruned_transducer_stateless2/exp/*.pt - rm -r data/lang_bpe_500 -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh deleted file mode 100755 index 2d0f80304..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt" -popd - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-25-avg-6.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless3/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless3/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless3/exp - ln -s $PWD/$repo/exp/pretrained.pt 
pruned_transducer_stateless3/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless3/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless3/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless3/exp - done - - rm pruned_transducer_stateless3/exp/*.pt - rm -r data/lang_bpe_500 -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh deleted file mode 100755 index 3d5814c48..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt -ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt -popd - - -log "Export to torchscript model" -./pruned_transducer_stateless3/export.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -./pruned_transducer_stateless3/export.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit-trace 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.trace()" - -./pruned_transducer_stateless3/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ - --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ - --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless3/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder_jit_script.pt \ - --decoder-model-filename $repo/exp/decoder_jit_script.pt \ - --joiner-model-filename $repo/exp/joiner_jit_script.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless3/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless3/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens 
$repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless3/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless3/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless3/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless3/exp - done - - rm pruned_transducer_stateless3/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh deleted file mode 100755 index 3d2442d54..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-39-avg-7.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless5/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --num-encoder-layers 18 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --encoder-dim 512 \ - --decoder-dim 512 \ - --joiner-dim 512 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless5/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav \ - --num-encoder-layers 18 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --encoder-dim 512 \ - --decoder-dim 512 \ - --joiner-dim 512 -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless5/exp - ln -s $PWD/$repo/exp/pretrained-epoch-39-avg-7.pt pruned_transducer_stateless5/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless5/exp - - log "Decoding test-clean and 
test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless5/decode.py \ - --decoding-method $method \ - --use-averaged-model 0 \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless5/exp \ - --num-encoder-layers 18 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --encoder-dim 512 \ - --decoder-dim 512 \ - --joiner-dim 512 - done - - rm pruned_transducer_stateless5/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh deleted file mode 100755 index 961dde4f4..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./pruned_transducer_stateless7/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless7/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless7/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless7/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless7/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless7/exp - - log "Decoding test-clean and test-other" - - # use a small value for 
decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless7/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless7/exp - done - - rm pruned_transducer_stateless7/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh deleted file mode 100755 index ba7139efb..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lm/G_4_gram.pt" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./pruned_transducer_stateless7_ctc/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless7_ctc/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --model-filename $repo/exp/cpu_jit.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless7_ctc/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless7_ctc/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - 
$repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ - --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless7_ctc/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless7_ctc/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless7_ctc/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless7_ctc/exp - done - - for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc/ctc_decode.py \ - --epoch 999 \ - --avg 1 \ - --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --max-duration $max_duration \ - --use-averaged-model 0 \ - --decoding-method $m \ - --hlg-scale 0.6 \ - --lm-dir data/lm - done - - rm pruned_transducer_stateless7_ctc/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh deleted file mode 100755 index 1ecbc4798..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./pruned_transducer_stateless7_ctc_bs/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename 
$repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ - --model-filename $repo/exp/cpu_jit.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ - --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" - -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless7_ctc_bs/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc_bs/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless7_ctc_bs/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless7_ctc_bs/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless7_ctc_bs/exp - done - - for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ - --epoch 999 \ - --avg 1 \ - --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --max-duration $max_duration \ - --use-averaged-model 0 \ - --decoding-method $m \ - --hlg-scale 0.6 - done - - rm pruned_transducer_stateless7_ctc_bs/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh deleted file mode 100755 index 37b192a57..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local 
fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -git lfs pull --include "exp/encoder_jit_trace.pt" -git lfs pull --include "exp/decoder_jit_trace.pt" -git lfs pull --include "exp/joiner_jit_trace.pt" -cd exp -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --decode-chunk-len 32 \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless7_streaming/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - --decode-chunk-len 32 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Export to torchscript model by torch.jit.trace()" -./pruned_transducer_stateless7_streaming/jit_trace_export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --decode-chunk-len 32 \ - --epoch 99 \ - --avg 1 - -log "Decode with models exported by torch.jit.trace()" - -./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ - --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ - --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ - --decode-chunk-len 32 \ - $repo/test_wavs/1089-134686-0001.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless7_streaming/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --decode-chunk-len 32 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless7_streaming/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --decode-chunk-len 32 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless7_streaming/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_streaming/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless7_streaming/exp - - log "Decoding test-clean 
and test-other" - - # use a small value for decoding with CPU - max_duration=100 - num_decode_stream=200 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "decoding with $method" - - ./pruned_transducer_stateless7_streaming/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --decode-chunk-len 32 \ - --exp-dir pruned_transducer_stateless7_streaming/exp - done - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --decode-chunk-len 32 \ - --num-decode-streams $num_decode_stream - --exp-dir pruned_transducer_stateless7_streaming/exp - done - - rm pruned_transducer_stateless7_streaming/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh deleted file mode 100755 index 4f2bfac24..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless8/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Export to torchscript model" -./pruned_transducer_stateless8/export.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model false \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless8/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless8/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless8/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens 
$repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless8/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless8/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless8/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless8/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless8/exp - done - - rm pruned_transducer_stateless8/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh deleted file mode 100755 index 5cbdad16d..000000000 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-24-avg-10.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --simulate-streaming 1 \ - --causal-convolution 1 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --simulate-streaming 1 \ - --causal-convolution 1 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless2/exp - ln -s $PWD/$repo/exp/pretrained-epoch-24-avg-10.pt pruned_transducer_stateless2/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless2/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search 
fast_beam_search modified_beam_search; do - log "Simulate streaming decoding with $method" - - ./pruned_transducer_stateless2/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless2/exp \ - --simulate-streaming 1 \ - --causal-convolution 1 - done - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Real streaming decoding with $method" - - ./pruned_transducer_stateless2/streaming_decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --num-decode-streams 100 \ - --exp-dir pruned_transducer_stateless2/exp \ - --left-context 32 \ - --decode-chunk-size 8 \ - --right-context 0 - done - - rm pruned_transducer_stateless2/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh b/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh deleted file mode 100755 index f4e2124b1..000000000 --- a/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "exp/jit_script_chunk_16_left_128.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./zipformer/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./zipformer/jit_pretrained_streaming.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \ - $repo/test_wavs/1089-134686-0001.wav - -for method in greedy_search modified_beam_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p zipformer/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh zipformer/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Simulated streaming decoding with $method" - - ./zipformer/decode.py \ - --causal 1 \ - 
--chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir zipformer/exp - done - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Chunk-wise streaming decoding with $method" - - ./zipformer/streaming_decode.py \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir zipformer/exp - done - - rm zipformer/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh deleted file mode 100755 index ff77855a2..000000000 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in fast_beam_search modified_beam_search beam_search; do - log "$method" - - ./transducer_stateless2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p transducer_stateless2/exp - ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh transducer_stateless2/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./transducer_stateless2/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir transducer_stateless2/exp - done - - rm transducer_stateless2/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-zipformer-2023-05-18.sh b/.github/scripts/run-librispeech-zipformer-2023-05-18.sh deleted file mode 100755 index fb1a0149d..000000000 --- a/.github/scripts/run-librispeech-zipformer-2023-05-18.sh +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} 
- -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "exp/jit_script.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./zipformer/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./zipformer/jit_pretrained.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --nn-model-filename $repo/exp/jit_script.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for method in greedy_search modified_beam_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p zipformer/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh zipformer/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./zipformer/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir zipformer/exp - done - - rm zipformer/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh b/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh deleted file mode 100755 index 0026d2109..000000000 --- a/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lm/G_4_gram.pt" -git lfs pull --include 
"exp/jit_script.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./zipformer/export.py \ - --exp-dir $repo/exp \ - --use-transducer 1 \ - --use-ctc 1 \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -for method in ctc-decoding 1best; do - ./zipformer/jit_pretrained_ctc.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --model-filename $repo/exp/jit_script.pt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --G $repo/data/lm/G_4_gram.pt \ - --method $method \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in ctc-decoding 1best; do - log "$method" - - ./zipformer/pretrained_ctc.py \ - --use-transducer 1 \ - --use-ctc 1 \ - --method $method \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --G $repo/data/lm/G_4_gram.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p zipformer/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh zipformer/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in ctc-decoding 1best; do - log "Decoding with $method" - - ./zipformer/ctc_decode.py \ - --use-transducer 1 \ - --use-ctc 1 \ - --decoding-method $method \ - --nbest-scale 1.0 \ - --hlg-scale 0.6 \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir zipformer/exp - done - - rm zipformer/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh deleted file mode 100755 index c59921055..000000000 --- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/3gram.pt" -git lfs pull --include "data/lang_bpe_500/4gram.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt 
-popd - -log "Export to torchscript model" -./zipformer_mmi/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./zipformer_mmi/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - --lang-dir $repo/data/lang_bpe_500 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do - log "$method" - - ./zipformer_mmi/pretrained.py \ - --method $method \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_bpe_500 \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p zipformer_mmi/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer_mmi/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh zipformer_mmi/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do - log "Decoding with $method" - - ./zipformer_mmi/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --nbest-scale 1.2 \ - --hp-scale 1.0 \ - --max-duration $max_duration \ - --lang-dir $repo/data/lang_bpe_500 \ - --exp-dir zipformer_mmi/exp - done - - rm zipformer_mmi/exp/*.pt -fi diff --git a/.github/scripts/run-pre-trained-ctc.sh b/.github/scripts/run-pre-trained-ctc.sh deleted file mode 100755 index 7d6449c9a..000000000 --- a/.github/scripts/run-pre-trained-ctc.sh +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -pushd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -log "CTC greedy search" - -./zipformer/onnx_pretrained_ctc.py \ - --nn-model $repo/model.onnx \ - --tokens $repo/tokens.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "CTC H decoding" - -./zipformer/onnx_pretrained_ctc_H.py \ - --nn-model $repo/model.onnx \ - --tokens $repo/tokens.txt \ - --H $repo/H.fst \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "CTC HL decoding" - -./zipformer/onnx_pretrained_ctc_HL.py \ - --nn-model $repo/model.onnx \ - --words $repo/words.txt \ - --HL $repo/HL.fst \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "CTC HLG decoding" - -./zipformer/onnx_pretrained_ctc_HLG.py \ - --nn-model $repo/model.onnx \ - --words $repo/words.txt \ - --HLG $repo/HLG.fst \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - 
$repo/test_wavs/2.wav - -rm -rf $repo - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo - -git lfs pull --include "exp/pretrained.pt" -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/L_disambig.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lang_bpe_500/lexicon.txt" -git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt" -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "data/lang_bpe_500/words.txt" -git lfs pull --include "data/lm/G_3_gram.fst.txt" - -popd - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -log "CTC decoding" - -./conformer_ctc/pretrained.py \ - --method ctc-decoding \ - --num-classes 500 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "HLG decoding" - -./conformer_ctc/pretrained.py \ - --method 1best \ - --num-classes 500 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "CTC decoding on CPU with kaldi decoders using OpenFst" - -log "Exporting model with torchscript" - -pushd $repo/exp -ln -s pretrained.pt epoch-99.pt -popd - -./conformer_ctc/export.py \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --jit 1 - -ls -lh $repo/exp - - -log "Generating H.fst, HL.fst" - -./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt - -ls -lh $repo/data/lang_bpe_500 - -log "Decoding with H on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_H.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --H $repo/data/lang_bpe_500/H.fst \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Decoding with HL on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_HL.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --HL $repo/data/lang_bpe_500/HL.fst \ - --words $repo/data/lang_bpe_500/words.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Decoding with HLG on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_HLG.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --HLG $repo/data/lang_bpe_500/HLG.fst \ - --words $repo/data/lang_bpe_500/words.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -rm -rf $repo - -popd - -log "Test aishell" - -pushd egs/aishell/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo - -git lfs pull --include "exp/pretrained.pt" -git lfs pull --include 
"data/lang_char/H.fst" -git lfs pull --include "data/lang_char/HL.fst" -git lfs pull --include "data/lang_char/HLG.fst" - -popd - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -log "CTC decoding" - -log "Exporting model with torchscript" - -pushd $repo/exp -ln -s pretrained.pt epoch-99.pt -popd - -./conformer_ctc/export.py \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_char/tokens.txt \ - --jit 1 - -ls -lh $repo/exp - -ls -lh $repo/data/lang_char - -log "Decoding with H on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_H.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --H $repo/data/lang_char/H.fst \ - --tokens $repo/data/lang_char/tokens.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "Decoding with HL on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_HL.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --HL $repo/data/lang_char/HL.fst \ - --words $repo/data/lang_char/words.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "Decoding with HLG on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_HLG.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --HLG $repo/data/lang_char/HLG.fst \ - --words $repo/data/lang_char/words.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -rm -rf $repo diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh deleted file mode 100755 index 7b686328d..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./transducer_stateless_multi_datasets/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p transducer_stateless_multi_datasets/exp - ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh transducer_stateless_multi_datasets/exp - - log 
"Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./transducer_stateless_multi_datasets/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir transducer_stateless_multi_datasets/exp - done - - rm transducer_stateless_multi_datasets/exp/*.pt -fi diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh deleted file mode 100755 index a8eeeb514..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./transducer_stateless_multi_datasets/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p transducer_stateless_multi_datasets/exp - ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh transducer_stateless_multi_datasets/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./transducer_stateless_multi_datasets/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir transducer_stateless_multi_datasets/exp - done - - rm transducer_stateless_multi_datasets/exp/*.pt -fi diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh deleted file mode 100755 index 2e2360435..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - 
-cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in fast_beam_search modified_beam_search beam_search; do - log "$method" - - ./transducer_stateless/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p transducer_stateless/exp - ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh transducer_stateless/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./transducer_stateless/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir transducer_stateless/exp - done - - rm transducer_stateless/exp/*.pt -fi diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh deleted file mode 100755 index b865f8d13..000000000 --- a/.github/scripts/run-pre-trained-transducer.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -log "Beam search decoding" - -./transducer/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml index 136e117bd..8b0599fca 100644 --- a/.github/workflows/aishell.yml +++ b/.github/workflows/aishell.yml @@ -11,22 +11,13 @@ on: workflow_dispatch: - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - concurrency: group: aishell-${{ github.ref }} cancel-in-progress: true jobs: generate_build_matrix: - if: 
(github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'aishell' || github.event_name == 'schedule') + if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell') # see https://github.com/pytorch/pytorch/pull/50633 runs-on: ubuntu-latest diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/librispeech.yml similarity index 95% rename from .github/workflows/train-librispeech.yml rename to .github/workflows/librispeech.yml index 79002a881..6e087b10a 100644 --- a/.github/workflows/train-librispeech.yml +++ b/.github/workflows/librispeech.yml @@ -1,4 +1,4 @@ -name: train librispeech +name: librispeech on: push: branches: @@ -11,7 +11,7 @@ on: workflow_dispatch: concurrency: - group: train-librispeech-${{ github.ref }} + group: librispeech-${{ github.ref }} cancel-in-progress: true jobs: @@ -32,7 +32,7 @@ jobs: python ./.github/scripts/docker/generate_build_matrix.py MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) echo "::set-output name=matrix::${MATRIX}" - train-librispeech: + librispeech: needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} runs-on: ubuntu-latest diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml deleted file mode 100644 index f092e3c80..000000000 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-03-12 -# stateless transducer + k2 pruned rnnt-loss - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_03_12-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_03_12: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh - - - name: Display decoding results for pruned_transducer_stateless - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless/exp - - cd pruned_transducer_stateless - echo "results for pruned_transducer_stateless" - echo "===greedy search===" 
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for pruned_transducer_stateless - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless-2022-03-12 - path: egs/librispeech/ASR/pruned_transducer_stateless/exp/ diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml deleted file mode 100644 index f8f4d9977..000000000 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-04-29 -# stateless pruned transducer (reworked model) + giga speech - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_04_29-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_04_29: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh - - .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh - - - name: Display decoding results for pruned_transducer_stateless2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR - tree pruned_transducer_stateless2/exp - cd 
pruned_transducer_stateless2/exp - echo "===greedy search===" - find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Display decoding results for pruned_transducer_stateless3 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR - tree pruned_transducer_stateless3/exp - cd pruned_transducer_stateless3/exp - echo "===greedy search===" - find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for pruned_transducer_stateless2 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-04-29 - path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/ - - - name: Upload decoding results for pruned_transducer_stateless3 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29 - path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/ diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml deleted file mode 100644 index dc20185da..000000000 --- a/.github/workflows/run-librispeech-2022-05-13.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-05-13 -# stateless transducer + k2 pruned rnnt-loss + deeper model - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_05_13-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_05_13: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless5 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless5/exp - - cd pruned_transducer_stateless5 - echo "results for 
pruned_transducer_stateless5" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless5 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless5-2022-05-13 - path: egs/librispeech/ASR/pruned_transducer_stateless5/exp/ diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml deleted file mode 100644 index 7e378c9a1..000000000 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-11-11-stateless7 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_11_11_zipformer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_11_11_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless7 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless7/exp - - cd pruned_transducer_stateless7 - echo "results for pruned_transducer_stateless7" - echo 
"===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless7 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-2022-11-11 - path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/ diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml deleted file mode 100644 index a2c1a0ad6..000000000 --- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-11-14-stateless8 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_11_14_zipformer_stateless8-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_11_14_zipformer_stateless8: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless8 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless8/exp - - cd pruned_transducer_stateless8 - echo "results for 
pruned_transducer_stateless8" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless8 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless8-2022-11-14 - path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml deleted file mode 100644 index 500ab1736..000000000 --- a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-12-01-stateless7-ctc -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -jobs: - run_librispeech_2022_11_11_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless7_ctc/exp - - cd pruned_transducer_stateless7_ctc - echo "results for pruned_transducer_stateless7_ctc" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best 
for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===ctc decoding===" - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-2022-12-01 - path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml deleted file mode 100644 index 1a7f9f594..000000000 --- a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2022 Zengwei Yao - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-12-08-zipformer-mmi -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_12_08_zipformer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_12_08_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh - - - name: Display decoding results for librispeech zipformer-mmi - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./zipformer-mmi/exp - - cd zipformer-mmi - echo "results for zipformer-mmi" - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for 
test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===nbest===" - find exp/nbest -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/nbest -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===nbest-rescoring-LG===" - find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===nbest-rescoring-3-gram===" - find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===nbest-rescoring-4-gram===" - find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech zipformer-mmi - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer_mmi-2022-12-08 - path: egs/librispeech/ASR/zipformer_mmi/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml deleted file mode 100644 index 68014e20c..000000000 --- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-12-29-stateless7-streaming -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_12_29_zipformer_streaming-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_12_29_zipformer_streaming: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless7_streaming - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree 
./pruned_transducer_stateless7_streaming/exp - - cd pruned_transducer_stateless7_streaming - echo "results for pruned_transducer_stateless7_streaming" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===streaming greedy search===" - find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===streaming fast_beam_search===" - find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===streaming modified beam search===" - find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - - name: Upload decoding results for librispeech pruned_transducer_stateless7_streaming - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-streaming-2022-12-29 - path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/ diff --git a/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml deleted file mode 100644 index 821abc25d..000000000 --- a/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2023-01-29-stateless7-ctc-bs -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -jobs: - run_librispeech_2023_01_29_zipformer_ctc_bs: - if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless7_ctc_bs/exp - - cd pruned_transducer_stateless7_ctc_bs - echo "results for pruned_transducer_stateless7_ctc_bs" - echo "===greedy search===" - find exp/greedy_search -name "log-*" 
-exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===ctc decoding===" - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc_bs - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2023-01-29 - path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/ diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml deleted file mode 100644 index 905515dc4..000000000 --- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-conformer-ctc3-2022-11-28 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_11_28_conformer_ctc3-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_11_28_conformer_ctc3: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh - - - name: Display decoding results for librispeech conformer_ctc3 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./conformer_ctc3/exp - - cd conformer_ctc3 - echo "results for conformer_ctc3" - echo "===ctc-decoding===" - find exp/ctc-decoding -name "log-*" 
-exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech conformer_ctc3 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-conformer_ctc3-2022-11-28 - path: egs/librispeech/ASR/conformer_ctc3/exp/ diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml deleted file mode 100644 index 3fb0920bc..000000000 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: run-librispeech-pruned-transducer-stateless3-2022-05-13 -# stateless pruned transducer (reworked model) + giga speech - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_pruned_transducer_stateless3_2022_05_13-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_pruned_transducer_stateless3_2022_05_13: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: 
cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh - - - name: Display decoding results for pruned_transducer_stateless3 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR - tree pruned_transducer_stateless3/exp - cd pruned_transducer_stateless3/exp - echo "===greedy search===" - find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for pruned_transducer_stateless3 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29 - path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/ diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml deleted file mode 100644 index 67a6f6fc4..000000000 --- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: run-librispeech-streaming-2022-06-26 -# streaming conformer stateless transducer2 - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_streaming_2022_06_26-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_streaming_2022_06_26: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export 
PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh - - - name: Display decoding results - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless2/exp - - cd pruned_transducer_stateless2 - echo "results for pruned_transducer_stateless2" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified_beam_search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for pruned_transducer_stateless2 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-06-26 - path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/ diff --git a/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml b/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml deleted file mode 100644 index 5145fb43c..000000000 --- a/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-streaming-zipformer-2023-05-18 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2023_05_18_streaming_zipformer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2023_05_18_streaming_zipformer: - if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh - - - name: Display decoding results for librispeech zipformer - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./zipformer/exp - - cd zipformer - - echo "results for zipformer, simulated streaming 
decoding" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "results for zipformer, chunk-wise streaming decoding" - echo "===greedy search===" - find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - - name: Upload decoding results for librispeech zipformer - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 - path: egs/librispeech/ASR/zipformer/exp/ diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml deleted file mode 100644 index 35ca08a31..000000000 --- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-04-19 -# stateless transducer + torchaudio rnn-t loss - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_04_19-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_04_19: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh - - - name: Display decoding results - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./transducer_stateless2/exp - - cd transducer_stateless2 - echo "results for transducer_stateless2" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n 
--color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified_beam_search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for transducer_stateless2 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless2-2022-04-19 - path: egs/librispeech/ASR/transducer_stateless2/exp/ diff --git a/.github/workflows/run-librispeech-zipformer-2023-05-18.yml b/.github/workflows/run-librispeech-zipformer-2023-05-18.yml deleted file mode 100644 index e9d235ad1..000000000 --- a/.github/workflows/run-librispeech-zipformer-2023-05-18.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-zipformer-2023-05-18 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2023_05_18_zipformer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2023_05_18_zipformer: - if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-zipformer-2023-05-18.sh - - - name: Display decoding results for librispeech zipformer - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./zipformer/exp - - cd zipformer - echo "results for zipformer" - echo "===greedy search===" - find exp/greedy_search -name "log-*" 
-exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech zipformer - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 - path: egs/librispeech/ASR/zipformer/exp/ diff --git a/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml b/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml deleted file mode 100644 index 48f0b1532..000000000 --- a/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-zipformer-ctc-2023-06-14 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2023_06_14_zipformer-ctc-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2023_06_14_zipformer_ctc: - if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh - - - name: Display decoding results for librispeech zipformer - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./zipformer/exp - - cd zipformer - echo "results for zipformer" - echo "===ctc-decoding===" - find exp/ctc-decoding 
-name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech zipformer - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 - path: egs/librispeech/ASR/zipformer/exp/ diff --git a/.github/workflows/run-pretrained-ctc.yml b/.github/workflows/run-pretrained-ctc.yml deleted file mode 100644 index 074a63dfc..000000000 --- a/.github/workflows/run-pretrained-ctc.yml +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: run-pre-trained-ctc - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - workflow_dispatch: - inputs: - test-run: - description: 'Test (y/n)?' 
- required: true - default: 'y' - -concurrency: - group: run_pre_trained_ctc-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-ctc.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml deleted file mode 100644 index f8caee8e5..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-100h - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh - - - name: Display decoding results for transducer_stateless_multi_datasets - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./transducer_stateless_multi_datasets/exp - - cd 
transducer_stateless_multi_datasets - echo "results for transducer_stateless_multi_datasets" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for transducer_stateless_multi_datasets - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-02-21 - path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/ diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml deleted file mode 100644 index 7c3910eb8..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-960h - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh - - - name: Display decoding results for transducer_stateless_multi_datasets - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./transducer_stateless_multi_datasets/exp - - cd 
transducer_stateless_multi_datasets - echo "results for transducer_stateless_multi_datasets" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for transducer_stateless_multi_datasets - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-03-01 - path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/ diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml deleted file mode 100644 index 1b69b97bf..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-transducer-stateless - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_pre_trained_transducer_stateless-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-pre-trained-transducer-stateless.sh - - - name: Display decoding results for transducer_stateless - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./transducer_stateless/exp - - cd transducer_stateless - echo "results for transducer_stateless" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color 
"best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for transducer_stateless - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless-2022-02-07 - path: egs/librispeech/ASR/transducer_stateless/exp/ diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml deleted file mode 100644 index 91d87f1c9..000000000 --- a/.github/workflows/run-pretrained-transducer.yml +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-transducer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - -concurrency: - group: run_pre_trained_transducer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer: - if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - make -j2 _kaldifeat - - - name: Inference with pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-transducer.sh From f42258caf8a1c4d19428d98b808986522f630843 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 30 Dec 2023 13:03:26 +0800 Subject: [PATCH 073/216] Update compute_fbank_commonvoice_splits.py (#1437) --- egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py index 0564f6ec6..f31b45aa5 100755 --- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py @@ -109,10 +109,10 @@ def compute_fbank_commonvoice_splits(args): extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") - set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance + set_audio_duration_mismatch_tolerance(0.05) # 50ms tolerance set_caching_enabled(False) for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) + idx = f"{i}".zfill(num_digits) logging.info(f"Processing {idx}/{num_splits}") cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz" From 8136ad775b6cd02bf2ecc60d65e8641b709c2d41 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 4 Jan 2024 13:59:32 +0800 Subject: [PATCH 074/216] Use high_freq -400 in computing fbank features. 
(#1447) See also https://github.com/k2-fsa/sherpa-onnx/issues/514 --- .../ASR/pruned_transducer_stateless2/pretrained.py | 1 + egs/aishell/ASR/conformer_ctc/pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py | 1 + .../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 + egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py | 1 + egs/aishell/ASR/transducer_stateless/pretrained.py | 1 + egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py | 1 + egs/aishell/ASR/transducer_stateless_modified/pretrained.py | 1 + egs/aishell/ASR/zipformer/streaming_decode.py | 1 + egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py | 1 + egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py | 1 + egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7/onnx_pretrained.py | 1 + egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py | 1 + .../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 + .../jit_trace_pretrained.py | 1 + egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py | 1 + .../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 + egs/gigaspeech/ASR/zipformer/streaming_decode.py | 1 + egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py | 1 + .../ASR/conformer_ctc/jit_pretrained_decode_with_H.py | 1 + .../ASR/conformer_ctc/jit_pretrained_decode_with_HL.py | 1 + .../ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py | 1 + egs/librispeech/ASR/conformer_ctc/pretrained.py | 1 + egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py | 1 + egs/librispeech/ASR/conformer_ctc3/pretrained.py | 1 + .../ASR/conv_emformer_transducer_stateless/streaming_decode.py | 1 + .../ASR/conv_emformer_transducer_stateless2/jit_pretrained.py | 1 + .../ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py | 1 + .../conv_emformer_transducer_stateless2/streaming-ncnn-decode.py | 1 + .../ASR/conv_emformer_transducer_stateless2/streaming_decode.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py | 1 + .../ASR/lstm_transducer_stateless/streaming_decode.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py | 1 + .../ASR/lstm_transducer_stateless2/onnx_pretrained.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py | 1 + .../ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py | 1 + .../ASR/lstm_transducer_stateless2/streaming-onnx-decode.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py | 1 + .../ASR/lstm_transducer_stateless3/streaming_decode.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py | 1 + .../ASR/pruned_transducer_stateless/streaming_decode.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py | 1 + .../ASR/pruned_transducer_stateless2/streaming_decode.py | 1 + .../ASR/pruned_transducer_stateless3/jit_pretrained.py | 1 + .../ASR/pruned_transducer_stateless3/onnx_pretrained.py | 1 + 
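A note on the single option this patch adds across all of the files above: as in Kaldi, a zero or negative `high_freq` is interpreted as an offset from the Nyquist frequency, so `high_freq = -400` places the top of the mel filterbank at `samp_freq / 2 - 400` Hz (7600 Hz for 16 kHz audio), presumably to match the feature configuration used when the models were trained (see the linked sherpa-onnx issue). Below is a minimal illustrative sketch, not part of the patch itself, assuming kaldifeat's Python `FbankOptions`/`Fbank` API as it appears in the diffs; the 16 kHz sample rate and 80 mel bins are example values copied from those hunks.

```python
#!/usr/bin/env python3
# Editorial sketch (not part of this patch): building a kaldifeat Fbank
# extractor with the option added throughout the commit above.
import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000  # example value, as in the diffs
opts.mel_opts.num_bins = 80        # example value, as in the diffs
# As in Kaldi, a non-positive high_freq is an offset from the Nyquist
# frequency: -400 keeps mel filters below 16000 / 2 - 400 = 7600 Hz.
opts.mel_opts.high_freq = -400

fbank = kaldifeat.Fbank(opts)

# Example: features for one second of dummy 16 kHz audio.
wave = torch.randn(16000) * 0.1
features = fbank(wave)  # shape: (num_frames, 80)
print(features.shape)
```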
egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py | 1 + .../ASR/pruned_transducer_stateless3/streaming_decode.py | 1 + .../ASR/pruned_transducer_stateless4/streaming_decode.py | 1 + .../pruned_transducer_stateless5/onnx_pretrained-streaming.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py | 1 + .../ASR/pruned_transducer_stateless5/streaming_decode.py | 1 + .../ASR/pruned_transducer_stateless7/jit_pretrained.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py | 1 + .../ASR/pruned_transducer_stateless7_ctc/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py | 1 + .../ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py | 1 + .../pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py | 1 + .../ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py | 1 + .../ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py | 1 + .../jit_trace_pretrained.py | 1 + .../pruned_transducer_stateless7_streaming/onnx_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_streaming/pretrained.py | 1 + .../streaming-ncnn-decode.py | 1 + .../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 + .../ASR/pruned_transducer_stateless8/jit_pretrained.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py | 1 + egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py | 1 + egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py | 1 + egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py | 1 + egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py | 1 + egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py | 1 + egs/librispeech/ASR/transducer/pretrained.py | 1 + egs/librispeech/ASR/transducer_stateless/pretrained.py | 1 + egs/librispeech/ASR/transducer_stateless2/pretrained.py | 1 + .../ASR/transducer_stateless_multi_datasets/pretrained.py | 1 + egs/librispeech/ASR/zipformer/jit_pretrained.py | 1 + egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py | 1 + egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py | 1 + egs/librispeech/ASR/zipformer/pretrained.py | 1 + egs/librispeech/ASR/zipformer/pretrained_ctc.py | 1 + egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py | 1 + egs/librispeech/ASR/zipformer_mmi/pretrained.py | 1 + egs/mgb2/ASR/conformer_ctc/pretrained.py | 1 + egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py | 1 + egs/multi_zh-hans/ASR/zipformer/pretrained.py | 1 + egs/multi_zh_en/ASR/zipformer/pretrained.py | 1 + egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_bbpe/pretrained.py | 1 + egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py | 1 + egs/tedlium3/ASR/transducer_stateless/pretrained.py | 1 + egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 1 + 
egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 1 + .../ASR/pruned_transducer_stateless2/jit_pretrained.py | 1 + egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py | 1 + .../pruned_transducer_stateless5/onnx_pretrained-streaming.py | 1 + .../ASR/pruned_transducer_stateless5/onnx_pretrained.py | 1 + egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py | 1 + .../ASR/pruned_transducer_stateless5/streaming_decode.py | 1 + egs/wenetspeech/ASR/zipformer/streaming_decode.py | 1 + egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py | 1 + egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py | 1 + egs/yesno/ASR/tdnn/jit_pretrained.py | 1 + egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py | 1 + egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py | 1 + egs/yesno/ASR/tdnn/onnx_pretrained.py | 1 + egs/yesno/ASR/tdnn/pretrained.py | 1 + 127 files changed, 127 insertions(+) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py index 75c316eaf..17729e02e 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py @@ -242,6 +242,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 66d583396..af1171a6f 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -261,6 +261,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py index 82c10f129..c4aa98358 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py @@ -240,6 +240,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py index ead393e6e..69fe3a40b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -241,6 +241,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py index e61190649..5143f2cae 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -230,6 +230,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py 
b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py index a92182e8d..8e8e971eb 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -369,6 +369,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py index 0c43bf74b..8fb7ac278 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -227,6 +227,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py index ea5bda4db..12004315b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -250,6 +250,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 6b4f183cf..aa0e07c83 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -317,6 +317,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 7e7213501..9754b4939 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -158,6 +158,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index 40f430e13..540e7b61b 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -258,6 +258,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index 5d8ca2e11..4a4e9237c 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -238,6 +238,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git 
a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 9e4459247..66a91709e 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -238,6 +238,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py index c3820447a..f54ffbd3c 100755 --- a/egs/aishell/ASR/zipformer/streaming_decode.py +++ b/egs/aishell/ASR/zipformer/streaming_decode.py @@ -572,6 +572,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index bc3ae7abf..f04632388 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -239,6 +239,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index ee898c303..e8b7f71b7 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -251,6 +251,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py index f5a0dd8c8..a738bb3fb 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py @@ -242,6 +242,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py index cf6ddfa36..52fed7331 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -370,6 +370,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py index a22d1b4ba..b6e2451e8 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = 
-400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index dbe65d0a7..018736d26 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -320,6 +320,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py index d84cf04a3..58ee99e6a 100644 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -177,6 +177,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py index 932026868..66fbae378 100644 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -252,6 +252,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 9700dd89e..7252665a7 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -337,6 +337,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py index a76788859..09df2935c 100755 --- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py +++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py @@ -553,6 +553,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py index 48fd2612a..458109a3f 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py @@ -264,6 +264,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py index 4bdec9e11..e9acf7e0b 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py +++ 
b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -195,6 +195,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py index d5a1dba3c..5753aa5d3 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -192,6 +192,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py index 216677a23..b6e3333ce 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -191,6 +191,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index df3e4d819..38b60fcb9 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -283,6 +283,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py index 76db46cc8..19b26361e 100755 --- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py @@ -271,6 +271,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py index c37b99cce..a0cdfcf03 100755 --- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -302,6 +302,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index e5a7c7116..9b8b4cce2 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -623,6 +623,7 @@ def create_streaming_feature_extractor() -> Fbank: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return Fbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py index 1fe358c79..58f587c91 100755 --- 
a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py @@ -184,6 +184,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py index a6c69d54f..c8aae04e8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -326,6 +326,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py index 74da9e6c8..1047100fc 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py @@ -276,6 +276,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index f5d894a7b..aaed7d31f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -623,6 +623,7 @@ def create_streaming_feature_extractor() -> Fbank: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py index c07956243..5350a54da 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -266,6 +266,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 119fcf1fd..42c3a5d7f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -251,6 +251,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index f989d9bc0..03472e2c3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ 
b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -615,6 +615,7 @@ def create_streaming_feature_extractor() -> Fbank: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index 728b09104..f4ec17221 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -267,6 +267,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3eeaa5397..5bab70fb0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -255,6 +255,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py index 06159e56a..06397965d 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py @@ -298,6 +298,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index 5d6d97320..dcff088e2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -254,6 +254,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index cbbc77928..6166049ae 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -217,6 +217,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 487fc2114..df9f6cf3f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -344,6 +344,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return 
OnlineFbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index 237591a36..d9e7f3578 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -266,6 +266,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 29a0d4d1a..e39637bd8 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -252,6 +252,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index c737e3611..c425b1f46 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -615,6 +615,7 @@ def create_streaming_feature_extractor() -> Fbank: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 02f9f1b03..e06404619 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -277,6 +277,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index f4b01fd06..8586c66d6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -334,6 +334,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 029f55ba0..6923f4d40 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -278,6 +278,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 9c4a13606..d17c3467a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -336,6 
+336,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 0669284b3..6d09de6bd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -285,6 +285,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index de3e03da6..8d12eae28 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -368,6 +368,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index abda4e2d4..05e6a6fba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -287,6 +287,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index e7c1affc2..5e1acd735 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -337,6 +337,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index e966aa4b1..229b52e5b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -353,6 +353,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py index 6e290e799..2432c6010 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -326,6 +326,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py 
b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 304fa8693..a9ce75a7b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -251,6 +251,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index f65f47fc2..8478a65fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -353,6 +353,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index 5af6dae25..88a05e09d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -225,6 +225,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 86c922cda..4bf11ac24 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py index 280b95984..83dc29324 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py @@ -224,6 +224,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py index d50d231d5..d1b7eec65 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -280,6 +280,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 78e0fa778..323ba2642 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False 
opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 904c1deae..1e638aa7d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -298,6 +298,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py index da2c6a39a..a39fdee54 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py @@ -224,6 +224,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py index 653c25e06..80604ef4a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py @@ -280,6 +280,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py index 494a34d97..0ff110370 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -381,6 +381,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index 5d240cf30..a82f3562b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index 914107526..b98756a54 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -298,6 +298,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = 
kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py index c8301b2da..7116b10fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -231,6 +231,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py index f2ac1914d..d714670cf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -186,6 +186,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py index 04861ea37..298d1889b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -382,6 +382,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index bc42e8d05..aa2dd17fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py index 883fdcbdd..999f7e0b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -335,6 +335,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index a0f54b6e1..e27fb4e63 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -320,6 +320,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + 
opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index 129497d5a..3ce2953c3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -225,6 +225,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 64b38c9d5..c29b8d8c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index fde724866..b3dfab64a 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -196,6 +196,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py index 3888d3544..0cd876551 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py @@ -224,6 +224,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py index 6f2cbaabd..92dea3aa1 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py @@ -280,6 +280,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py index 981039b8f..5c6956324 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py @@ -262,6 +262,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py index a06d6d684..7698ada79 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py @@ -298,6 +298,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + 
opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index c2413f5de..4d9bbf4b1 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -235,6 +235,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 5898dd0f5..3b86e319e 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -247,6 +247,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index b69b347ef..2de4182f1 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -247,6 +247,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 4f29d6f1f..83094ea51 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -247,6 +247,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained.py b/egs/librispeech/ASR/zipformer/jit_pretrained.py index a41fbc1c9..52dfd3fb6 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained.py @@ -222,6 +222,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index 660a4bfc6..fcd07ae34 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -285,6 +285,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py index d4ceacefd..eade5a854 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py @@ -167,6 +167,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return 
OnlineFbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py index 44546cae5..dd47c0eb6 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -318,6 +318,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index e7c4f40ee..e011c4b24 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -413,6 +413,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index 334376093..662392b5f 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -369,6 +369,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py index eb5cee9cd..ecca758f2 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -161,6 +161,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py index 683a7dc20..a77c3bf2a 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -225,6 +225,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 logging.info(f"Loading H from {args.H}") H = kaldifst.StdVectorFst.read(args.H) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py index 0b94bfa65..6ef944514 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -223,6 +223,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 logging.info(f"Loading HL from {args.HL}") HL = kaldifst.StdVectorFst.read(args.HL) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py index 93569142a..ccb3107ea 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -223,6 +223,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + 
opts.mel_opts.high_freq = -400 logging.info(f"Loading HLG from {args.HLG}") HLG = kaldifst.StdVectorFst.read(args.HLG) diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index 3104b6084..de0652893 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -303,6 +303,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index 9dff2e6fc..408d13576 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -304,6 +304,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py index c9ef16ffa..6990c90a0 100755 --- a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py @@ -259,6 +259,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 3ba4da5dd..1e7afc777 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -282,6 +282,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py index d30ca98d8..0ab2af527 100755 --- a/egs/mgb2/ASR/conformer_ctc/pretrained.py +++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py @@ -287,6 +287,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py index 77ba0873b..81a16f0ff 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py @@ -249,6 +249,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py index 69ff382da..c15db11f7 100755 --- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py +++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py @@ -303,6 +303,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py index 676272e1f..2fcde550b 100755 --- 
a/egs/multi_zh_en/ASR/zipformer/pretrained.py +++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py @@ -306,6 +306,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py index 3305f5bd3..8a74ee745 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py @@ -248,6 +248,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py index a23e2a04f..8c966a2f6 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -226,6 +226,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py index f365986f6..6e07b5949 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -261,6 +261,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py index 8a89c3578..9e58fed00 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py @@ -256,6 +256,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py index 81afd6a4e..5300fe764 100644 --- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py @@ -270,6 +270,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py index 3fdf3b855..0d77bc512 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py @@ -196,6 +196,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py index 98c746ce5..f06c8c211 100644 --- 
a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py @@ -196,6 +196,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py index f90dd2b43..aee1a2175 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py @@ -285,6 +285,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index c3d67ad92..642de72d7 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -238,6 +238,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py index c31db6859..cca26feb0 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -327,6 +327,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py index c784853ee..4b4ddd332 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py @@ -376,6 +376,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py index 1cac20435..17428e19d 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -238,6 +238,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 3a4dc3cb8..27a9b1714 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -378,6 +378,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + 
opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py index 94c5fae5f..96f339b07 100755 --- a/egs/wenetspeech/ASR/zipformer/streaming_decode.py +++ b/egs/wenetspeech/ASR/zipformer/streaming_decode.py @@ -572,6 +572,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py index 74a2210c3..2c106c4cb 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py @@ -249,6 +249,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py index d05bafcfb..6995ff2ff 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index 7581ecb83..e29415ffb 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -142,6 +142,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py index ff8c742af..72127aebd 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py @@ -164,6 +164,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 23 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py index 05ba74f9a..f8a057336 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py @@ -163,6 +163,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 23 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py index 72a1d69c8..968a9e9a8 100755 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -186,6 +186,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index 987c49de6..bea520998 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py 
@@ -164,6 +164,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) From 716b82cc3ada9e39254acc93465ed85e53d05670 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Fri, 5 Jan 2024 03:21:27 +0100 Subject: [PATCH 075/216] streaming_decode.py, relax the audio range from [-1,+1] to [-10,+10] (#1448) - some AudioTransform classes produce audio signals out of range [-1,+1] - Resample produced 1.0079 - The range [-10,+10] was chosen to still be able to reliably distinguish from the [-32k,+32k] signal... - this is related to : https://github.com/lhotse-speech/lhotse/issues/1254 --- .../streaming_decode.py | 7 ++++++- egs/aishell/ASR/zipformer/streaming_decode.py | 12 ++++++------ .../streaming_decode.py | 7 ++++++- egs/gigaspeech/ASR/zipformer/streaming_decode.py | 7 ++++++- .../streaming_decode.py | 8 +++++++- .../streaming_decode.py | 8 +++++++- .../lstm_transducer_stateless/streaming_decode.py | 8 +++++++- .../lstm_transducer_stateless3/streaming_decode.py | 8 +++++++- .../pruned_transducer_stateless/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless2/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless3/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless4/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless5/streaming_decode.py | 7 ++++++- .../streaming_decode.py | 7 ++++++- .../streaming_decode.py | 7 ++++++- egs/librispeech/ASR/zipformer/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless5/streaming_decode.py | 8 ++++++++ egs/wenetspeech/ASR/zipformer/streaming_decode.py | 12 ++++++------ 18 files changed, 114 insertions(+), 27 deletions(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index aa0e07c83..a4b5cd588 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -342,7 +342,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py index f54ffbd3c..6a7ef2750 100755 --- a/egs/aishell/ASR/zipformer/streaming_decode.py +++ b/egs/aishell/ASR/zipformer/streaming_decode.py @@ -597,12 +597,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - if audio.max() > 1: - logging.warning( - f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." - f"Clipping to [-1, 1]." - ) - audio = np.clip(audio, -1, 1) + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 7252665a7..6a249dd3f 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -362,7 +362,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py index 09df2935c..7cada8c9d 100755 --- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py +++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py @@ -578,7 +578,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 9b8b4cce2..12953c74c 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -681,8 +681,14 @@ def decode_dataset( assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) feature = fbank(samples) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index aaed7d31f..ddc7dbef1 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -681,8 +681,14 @@ def decode_dataset( assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... 
+ # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) feature = fbank(samples) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index 03472e2c3..14cb0fdfe 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -673,8 +673,14 @@ def decode_dataset( assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) feature = fbank(samples) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index c425b1f46..f57bdea67 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -673,8 +673,14 @@ def decode_dataset( assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) feature = fbank(samples) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index 8586c66d6..4726d9fad 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -359,7 +359,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d17c3467a..381561359 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -361,7 +361,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 5e1acd735..9113cfaa9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -362,7 +362,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 229b52e5b..f205ad42f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -378,7 +378,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 8478a65fb..1d980f10e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -378,7 +378,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index e27fb4e63..0961e0d7b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -345,7 +345,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py index 2904f086c..cc2787d76 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py @@ -345,7 +345,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 904caf8af..8087c1460 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -577,7 +577,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 27a9b1714..b396aa9b8 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -402,6 +402,14 @@ def decode_dataset( assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
+ samples = torch.from_numpy(audio).squeeze(0) fbank = Fbank(opts) diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py index 96f339b07..cb2cf7d35 100755 --- a/egs/wenetspeech/ASR/zipformer/streaming_decode.py +++ b/egs/wenetspeech/ASR/zipformer/streaming_decode.py @@ -597,12 +597,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - if audio.max() > 1: - logging.warning( - f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." - f"Clipping to [-1, 1]." - ) - audio = np.clip(audio, -1, 1) + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) From b9b56eb879e694684156b6ba441a1c665ff26e19 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 8 Jan 2024 14:28:07 +0800 Subject: [PATCH 076/216] Minor fixes to the VCTK data prep scripts (#1441) * Update prepare.sh --- egs/vctk/TTS/prepare.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh index 87150ad31..152c7b168 100755 --- a/egs/vctk/TTS/prepare.sh +++ b/egs/vctk/TTS/prepare.sh @@ -7,6 +7,7 @@ set -eou pipefail stage=0 stop_stage=100 +use_edinburgh_vctk_url=true dl_dir=$PWD/download @@ -44,7 +45,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # ln -sfv /path/to/VCTK $dl_dir/VCTK # if [ ! -d $dl_dir/VCTK ]; then - lhotse download vctk $dl_dir + lhotse download vctk --use-edinburgh-vctk-url ${use_edinburgh_vctk_url} $dl_dir fi fi @@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to $dl_dir/VCTK mkdir -p data/manifests if [ ! 
-e data/manifests/.vctk.done ]; then - lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests + lhotse prepare vctk --use-edinburgh-vctk-url ${use_edinburgh_vctk_url} $dl_dir/VCTK data/manifests touch data/manifests/.vctk.done fi fi From 5445ea6df6e250f8f1e9b2df3bb9e54afe104f97 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 8 Jan 2024 15:09:21 +0800 Subject: [PATCH 077/216] Use shuffled LibriSpeech cuts instead (#1450) * use shuffled LibriSpeech cuts instead * leave the old code in comments for reference --- egs/librispeech/ASR/conformer_ctc3/train.py | 15 ++++++++++++--- egs/librispeech/ASR/conformer_mmi/train.py | 16 +++++++++++++--- .../ASR/lstm_transducer_stateless3/train.py | 15 ++++++++++++--- .../ASR/pruned2_knowledge/train.py | 15 ++++++++++++--- .../train.py | 19 ++++++++++++++++--- .../train.py | 11 ++++++++--- egs/librispeech/ASR/zipformer/train.py | 15 ++++++++++++--- egs/librispeech/ASR/zipformer_mmi/train.py | 8 +++++--- 8 files changed, 90 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index 2cd223945..a2f1125ca 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -952,10 +952,19 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts + # strictly speaking, shuffled training cuts should be used instead + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index f9f80632e..fe8c85f61 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -771,10 +771,20 @@ def run(rank, world_size, args): valid_ali = None librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 6ef4c9860..2c1cef3a3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ 
b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -989,10 +989,19 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index a4899f7bd..931341cc4 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -817,10 +817,19 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 2d915ff87..e1bdce49d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1038,13 +1038,26 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) + assert not ( + params.mini_libri and params.full_libri + ), f"Cannot set both mini-libri and full-libri flags to True, now mini-libri {params.mini_libri} and full-libri {params.full_libri}" + if params.mini_libri: train_cuts = librispeech.train_clean_5_cuts() else: - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep 
only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index 565dc7a16..1642ef4b7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -1150,10 +1150,15 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 7009f3346..3ccf7d2f1 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1174,10 +1174,19 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index 4b50acdde..dd8949523 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -990,11 +990,13 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - # train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - # train_cuts += librispeech.train_clean_360_cuts() - # train_cuts += librispeech.train_other_500_cuts() train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets else: train_cuts = librispeech.train_clean_100_cuts() From e2fcb42f5f176d9e39eb38506ab99d0a3adaf202 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 9 Jan 2024 15:41:37 +0800 Subject: [PATCH 078/216] fix typo (#1455) --- .../RNN-LM/librispeech/lm-training.rst | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst 
b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst index 46499a374..e0c90f2a6 100644 --- a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst +++ b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst @@ -4,7 +4,7 @@ Train an RNN language model ====================================== If you have enough text data, you can train a neural network language model (NNLM) to improve -the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from +the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from scratch. .. HINT:: @@ -15,23 +15,23 @@ scratch. .. note:: This tutorial is based on the LibriSpeech recipe. Please check it out for the necessary - python scripts for this tutorial. We use the LibriSpeech LM-corpus as the LM training set + python scripts for this tutorial. We use the LibriSpeech LM-corpus as the LM training set for illustration purpose. You can also collect your own data. The data format is quite simple: each line should contain a complete sentence, and words should be separated by space. -First, let's download the training data for the RNNLM. This can be done via the +First, let's download the training data for the RNNLM. This can be done via the following command: .. code-block:: bash - $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz + $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz $ gzip -d librispeech-lm-norm.txt.gz As we are training a BPE-level RNNLM, we need to tokenize the training text, which requires a BPE tokenizer. This can be achieved by executing the following command: .. code-block:: bash - + $ # if you don't have the BPE $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 $ cd icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500 @@ -56,11 +56,11 @@ sentence length. --out-statistics data/lang_bpe_500/lm_data_stats.txt -The aforementioned steps can be repeated to create a a validation set for you RNNLM. Let's say -you have a validation set in ``valid.txt``, you can just set ``--lm-data valid.txt`` +The aforementioned steps can be repeated to create a a validation set for you RNNLM. Let's say +you have a validation set in ``valid.txt``, you can just set ``--lm-data valid.txt`` and ``--lm-archive data/lang_bpe_500/lm-data-valid.pt`` when calling ``./local/prepare_lm_training_data.py``. -After completing the previous steps, the training and testing sets for training RNNLM are ready. +After completing the previous steps, the training and testing sets for training RNNLM are ready. The next step is to train the RNNLM model. The training command is as follows: .. code-block:: bash @@ -77,7 +77,7 @@ The next step is to train the RNNLM model. The training command is as follows: --use-fp16 0 \ --tie-weights 1 \ --embedding-dim 2048 \ - --hidden_dim 2048 \ + --hidden-dim 2048 \ --num-layers 3 \ --batch-size 300 \ --lm-data rnn_lm/data/lang_bpe_500/sorted_lm_data.pt \ @@ -93,12 +93,3 @@ The next step is to train the RNNLM model. The training command is as follows: .. note:: The training of RNNLM can take a long time (usually a couple of days). 
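(Editorial sketch, not part of the patch.) The tutorial patched above describes the LM corpus format, one sentence per line with space-separated words, and a BPE tokenization step handled by `./local/prepare_lm_training_data.py`. The sketch below only illustrates what that tokenization amounts to; the model and corpus paths are assumptions taken from the tutorial's earlier commands.

```python
# Minimal sketch of the tokenization step described in the tutorial:
# each line of the LM corpus is a sentence that is mapped to BPE token ids.
# Paths below are assumptions based on the commands shown in the tutorial.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/bpe.model")

with open("librispeech-lm-norm.txt") as f:
    for line in f:
        sentence = line.strip()
        if not sentence:
            continue
        ids = sp.encode(sentence, out_type=int)
        # The real script archives such id sequences into a .pt file, which a
        # later sorting step turns into the sorted_lm_data.pt used for training.
        print(len(ids), ids[:10])
        break  # only show the first sentence here
```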
- - - - - - - - - From 398401ed277d4f895f624a95919c57edbbde4cba Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 14 Jan 2024 14:38:41 +0800 Subject: [PATCH 079/216] Update kaldifeat installation doc (#1460) --- docs/source/for-dummies/environment-setup.rst | 4 ++-- docs/source/for-dummies/model-export.rst | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/for-dummies/environment-setup.rst b/docs/source/for-dummies/environment-setup.rst index 0cb8ecc1d..a68e9d3ed 100644 --- a/docs/source/for-dummies/environment-setup.rst +++ b/docs/source/for-dummies/environment-setup.rst @@ -66,13 +66,13 @@ to install dependencies of `icefall`_: pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - # If you are using macOS or Windows, please use the following command to install torch and torchaudio + # If you are using macOS, please use the following command to install torch and torchaudio # pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html # Now install k2 # Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example - pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html + pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html # Install the latest version of lhotse diff --git a/docs/source/for-dummies/model-export.rst b/docs/source/for-dummies/model-export.rst index 079ebc712..352a0dc90 100644 --- a/docs/source/for-dummies/model-export.rst +++ b/docs/source/for-dummies/model-export.rst @@ -85,7 +85,7 @@ We can also use it to decode files with the following command: # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html # for how to install kaldifeat - pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html ./tdnn/pretrained.py \ --checkpoint ./tdnn/exp/pretrained.pt \ @@ -162,7 +162,7 @@ To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use: # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html # for how to install kaldifeat - pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html ./tdnn/jit_pretrained.py \ @@ -249,7 +249,7 @@ To use the generated ONNX model files for decoding with `onnxruntime`_, we can u # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html # for how to install kaldifeat - pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html ./tdnn/onnx_pretrained.py \ --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ From 7bdde9174c7c95a32a10d6dcbc3764ecb4873b1d Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 16 Jan 2024 21:08:35 +0800 Subject: [PATCH 080/216] A Zipformer recipe with Byte-level BPE for Aishell-1 (#1464) * init commit * Update train.py * Update decode.py * Update RESULTS.md * added `vocab_size` * removed unused softlinks * added scripts for testing pretrained models * set `bpe_model` as required * re-org the bbpe recipe for aishell --- 
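(Editorial sketch, not part of the patch.) The commit below adds a byte-level BPE (BBPE) Zipformer recipe for Aishell-1, whose decoding scripts rely on `byte_encode`, `smart_byte_decode`, and `tokenize_by_CJK_char`, all imported from `icefall` in the code that follows. A minimal round-trip sketch of that text handling is shown here; the bbpe.model path and the sample sentence are assumptions.

```python
# Minimal round-trip sketch of the byte-level BPE handling used by the new
# decode_bbpe.py / pretrained_bbpe.py scripts added in this commit.
# The model path and the sample text are assumptions.
import sentencepiece as spm

from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_500/bbpe.model")

text = "今天天气很好"
# CJK characters are split per character, then mapped to printable byte
# symbols so that an ordinary BPE model can operate on them.
encoded = byte_encode(tokenize_by_CJK_char(text))
ids = sp.encode(encoded, out_type=int)

# Decoding reverses the mapping; smart_byte_decode is meant to be robust to
# byte sequences that do not form valid characters.
print(smart_byte_decode(sp.decode(ids)))
```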
egs/aishell/ASR/RESULTS.md | 56 +- egs/aishell/ASR/zipformer/decode_bbpe.py | 840 ++++++++++++++++ .../ASR/zipformer/jit_pretrained_bbpe.py | 279 ++++++ egs/aishell/ASR/zipformer/pretrained_bbpe.py | 403 ++++++++ egs/aishell/ASR/zipformer/train_bbpe.py | 942 ++++++++++++++++++ 5 files changed, 2518 insertions(+), 2 deletions(-) create mode 100755 egs/aishell/ASR/zipformer/decode_bbpe.py create mode 100755 egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py create mode 100755 egs/aishell/ASR/zipformer/pretrained_bbpe.py create mode 100755 egs/aishell/ASR/zipformer/train_bbpe.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 0b22f41a1..ff9504274 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -2,9 +2,61 @@ ### Aishell training result (Stateless Transducer) +#### Zipformer (Byte-level BPE) + +[./zipformer](./zipformer/) + +It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `vocab_size` set to 500. + +##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M + +| | test | dev | comment | +|------------------------|------|------|-----------------------------------------| +| greedy search | 4.54 | 4.31 | --epoch 40 --avg 10 | +| modified beam search | 4.37 | 4.11 | --epoch 40 --avg 10 | +| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 | + +```bash +./prepare.sh + +export CUDA_VISIBLE_DEVICES="0,1" + +./zipformer/train_bbpe.py \ + --world-size 2 \ + --num-epochs 40 \ + --start-epoch 1 \ + --use-fp16 1 \ + --context-size 2 \ + --enable-musan 0 \ + --exp-dir zipformer/exp_bbpe \ + --max-duration 1000 \ + --enable-musan 0 \ + --base-lr 0.045 \ + --lr-batches 7500 \ + --lr-epochs 10 \ + --spec-aug-time-warp-factor 20 +``` + +Command for decoding is: +```bash +for m in greedy_search modified_beam_search fast_beam_search ; do + ./zipformer/decode_bbpe.py \ + --epoch 40 \ + --avg 10 \ + --exp-dir ./zipformer_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --context-size 2 \ + --decoding-method $m +done +``` +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at + + + #### Zipformer (Non-streaming) -[./zipformer](./zipformer) +[./zipformer](./zipformer/) It's reworked Zipformer with Pruned RNNT loss. **Caution**: It uses `--context-size=1`. @@ -260,7 +312,7 @@ done Pretrained models, training logs, decoding logs, and decoding results are available at -#### Pruned transducer stateless 7 (zipformer) +#### Pruned transducer stateless 7 (Byte-level BPE) See diff --git a/egs/aishell/ASR/zipformer/decode_bbpe.py b/egs/aishell/ASR/zipformer/decode_bbpe.py new file mode 100755 index 000000000..1ec10b059 --- /dev/null +++ b/egs/aishell/ASR/zipformer/decode_bbpe.py @@ -0,0 +1,840 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Mingshuang Luo, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./zipformer/decode_bbpe.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode_bbpe.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode_bbpe.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode_bbpe.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode_bbpe.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_bbpe/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the byte BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_500/", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). 
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + lexicon: Lexicon, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + hyps.append([lexicon.word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + ref_texts = [] + for tx in supervisions["text"]: + ref_texts.append(byte_encode(tokenize_by_CJK_char(tx))) + + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(ref_texts), + nbest_scale=params.nbest_scale, + 
blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + lexicon: + directory containing the lexicon. + sp: + SentencePiece model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
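(Editorial sketch, not part of the patch.) `decode_one_batch` above threads the new `--blank-penalty` value into every search; its help text describes the operation as `logits[:, 0] -= blank_penalty`. A tiny standalone illustration of that step, with made-up shapes and an example penalty value, follows.

```python
# Standalone sketch of the blank-penalty operation described by the
# --blank-penalty help text: subtract a constant from the blank logit
# (index 0) before picking tokens, which discourages blank emissions.
import torch

logits = torch.randn(4, 500)   # (batch_size, vocab_size); blank id is 0
blank_penalty = 0.5            # example value of --blank-penalty
logits[:, 0] -= blank_penalty
tokens = logits.argmax(dim=1)  # blanks become less likely to win the argmax
print(tokens)
```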
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + lexicon=lexicon, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = "".join(ref_text.split()) + + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + lexicon = Lexicon(params.lang_dir) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + 
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py new file mode 100755 index 000000000..cd16284f7 --- /dev/null +++ b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./zipformer/export.py \ + --exp-dir ./zipformer_bbpe/exp \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zipformer/jit_pretrained.py \ + --nn-model-filename ./zipformer_bbpe/exp/cpu_jit.pt \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + required=True, + help="""Path to the bbpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. 
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = model.decoder.blank_id + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + features=features, + feature_lengths=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = smart_byte_decode(sp.decode(hyp)) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if 
__name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/zipformer/pretrained_bbpe.py b/egs/aishell/ASR/zipformer/pretrained_bbpe.py new file mode 100755 index 000000000..387bef98a --- /dev/null +++ b/egs/aishell/ASR/zipformer/pretrained_bbpe.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp_bbpe \ + --tokens ./data/lang_bbpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp_bbpe \ + --causal 1 \ + --tokens ./data/lang_bbpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +You can also use `./zipformer/exp_bbpe/epoch-xx.pt`. 
+ +Note: ./zipformer/exp_bbpe/pretrained.pt is generated by ./zipformer/export_bbpe.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + required=True, + help="""Path to the bbpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( 
+ model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py new file mode 100755 index 000000000..a2bf96b29 --- /dev/null +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -0,0 +1,942 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./zipformer/train_bbpe.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --exp-dir zipformer/exp_bbpe \ + --max-duration 350 + +# For mix precision training: + +./zipformer/train_bbpe.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp_bbpe \ + --max-duration 750 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from typing import Optional, Tuple, Union + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from train import ( + LRSchedulerType, + add_model_arguments, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + set_batch_count, +) + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + tokenize_by_CJK_char, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + 
"--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_bbpe/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the Byte BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
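+    # Because the loss is a sum over all utterances in the batch (no
+    # normalization yet), we also record the number of subsampled frames;
+    # an average per-frame loss is recovered later as
+    # tot_loss["loss"] / tot_loss["frames"].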
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. 
+ The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + valid_cuts = aishell.valid_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 15.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_cuts = train_cuts.map(tokenize_and_encode_text) + + valid_cuts = valid_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The sentence piece model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 5dfc3ed7f995f30a4b64e57071a095566f8126ea Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 21 Jan 2024 02:10:42 +0800 Subject: [PATCH 081/216] Fix buffer size of DynamicBucketingSampler (#1468) * Fix buffer size * Fix for flake8 --------- Co-authored-by: yifanyeung --- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++- egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++ .../ASR/transducer_stateless_modified-2/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 3 ++- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++- .../ASR_v2/pruned_transducer_stateless7/asr_datamodule.py | 2 ++ egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 ++ egs/ami/SURT/dprnn_zipformer/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 ++ .../commonvoice_fr.py | 2 ++ egs/csj/ASR/local/utils/asr_datamodule.py | 2 ++ egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 ++ egs/gigaspeech/ASR/zipformer/asr_datamodule.py | 2 ++ egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py | 2 ++ egs/libriheavy/ASR/zipformer/asr_datamodule.py | 2 ++ egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py | 2 ++ egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless3/asr_datamodule.py | 6 ++++++ .../ASR/pruned_transducer_stateless7/gigaspeech.py | 2 ++ egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++ egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py | 2 ++ egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py | 2 ++ egs/ljspeech/TTS/vits/tts_datamodule.py | 2 ++ egs/mgb2/ASR/conformer_ctc/asr_datamodule.py | 2 ++ egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py | 2 ++ egs/multi_zh_en/ASR/zipformer/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 ++ egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 3 ++- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 3 ++- egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py | 2 ++ egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++ 
egs/vctk/TTS/vits/tts_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 ++ egs/yesno/ASR/tdnn/asr_datamodule.py | 2 ++ 37 files changed, 78 insertions(+), 6 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index d491996b2..e29dd8ab5 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -288,8 +288,9 @@ class Aidatatang_200zhAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, - buffer_size=50000, ) else: logging.info("Using SimpleCutSampler.") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index 6abe6c084..aacbd153d 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -275,6 +275,8 @@ class AishellAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index cd8dd821c..ed453afd2 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -226,6 +226,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py index 8f6a88f59..f9cdfb621 100644 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -296,6 +296,8 @@ class AiShell2AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index e6db2651f..c10456da5 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -306,7 +306,8 @@ class Aishell4AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=100000, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index 5ad80817a..410741215 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ 
b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -288,7 +288,8 @@ class AlimeetingAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=30000, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py index 9d288218a..6b56c8a6a 100644 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py @@ -263,6 +263,8 @@ class AlimeetingAsrDataModule: max_cuts=self.args.max_cuts, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) logging.info("About to create train dataloader") diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py index 79474f1d8..554facfc1 100644 --- a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -269,6 +269,8 @@ class AmiAsrDataModule: max_cuts=self.args.max_cuts, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) logging.info("About to create train dataloader") diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py index 1549c1631..ea8b62242 100644 --- a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py @@ -254,6 +254,8 @@ class AmiAsrDataModule: max_cuts=self.args.max_cuts, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py index 546e9f9dd..c40d9419b 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -308,6 +308,8 @@ class CommonVoiceAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py index cafa4111d..79cf86b84 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py @@ -310,6 +310,8 @@ class CommonVoiceAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py index 042b6ecbf..7bf7bdef0 100644 --- a/egs/csj/ASR/local/utils/asr_datamodule.py +++ 
b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -336,6 +336,8 @@ class CSJAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index a93e224d5..569978424 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -261,6 +261,8 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index b5b27ce95..40339365c 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -294,6 +294,8 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 6adfdbfbb..850ab7c10 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -311,6 +311,8 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py index c1abdbdb5..500df9ea4 100644 --- a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py @@ -256,6 +256,8 @@ class LibriCssAsrDataModule: max_cuts=self.args.max_cuts, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py index df761c1b8..e23c9b1b7 100644 --- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -310,6 +310,8 @@ class LibriHeavyAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py index 690003377..1a4c9a532 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py @@ -341,6 +341,8 @@ class LibriHeavyAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + 
shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index ee7556e49..be36c06b6 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -286,6 +286,8 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, bucket_method="equal_duration", drop_last=True, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 057624272..87c62789e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -223,6 +223,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) @@ -256,6 +258,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=False, ) logging.info("About to create dev dataloader") @@ -282,6 +286,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, ) logging.debug("About to create test dataloader") test_dl = DataLoader( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py index cd432fd6f..306f30c2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py @@ -294,6 +294,8 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index c500eb3e5..dd9e9ef1f 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -311,6 +311,8 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py index 3acd22ae4..84bd3fc4b 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py @@ -304,6 +304,8 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git 
a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py index 2f8e658c5..e1a29bd9c 100644 --- a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py +++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py @@ -227,6 +227,8 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index 81bb9ed13..8ff868bc8 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -196,6 +196,8 @@ class LJSpeechTtsDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py index 7753d1674..48921d71f 100644 --- a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py @@ -266,6 +266,8 @@ class MGB2AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index 02cfa1346..341579acb 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -297,6 +297,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py index be6e94472..662ae01c5 100644 --- a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py @@ -294,6 +294,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index cf70fc0f8..7cd6771ce 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -236,6 +236,8 @@ class SPGISpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) logging.info("About to create train dataloader") diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index ce8634a1d..0f6f02e8d 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -298,8 +298,9 @@ class 
SwitchBoardAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, - buffer_size=50000, ) else: logging.info("Using SimpleCutSampler.") diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 5269a1778..6f0833db6 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -306,8 +306,9 @@ class TAL_CSASRAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, num_cuts_for_bins_estimate=20000, - buffer_size=60000, drop_last=self.args.drop_last, ) else: diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index d4a9e4bc9..a67cf8d04 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -256,6 +256,8 @@ class TedLiumAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py index 5d1b3c367..8606a490b 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -222,6 +222,8 @@ class TimitAsrDataModule(DataModule): max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py index 8b2a96b09..52fc5179f 100644 --- a/egs/vctk/TTS/vits/tts_datamodule.py +++ b/egs/vctk/TTS/vits/tts_datamodule.py @@ -204,6 +204,8 @@ class VctkTtsDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 1dbfb9709..58da1d68c 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -292,7 +292,8 @@ class WenetSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=300000, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py index 7594fb28e..7b37b1331 100644 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -296,6 +296,8 @@ class Xbmu_AmdoAsrDataModule: 
max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index dc66b217d..b9ce8fb4e 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -193,6 +193,8 @@ class YesNoAsrDataModule(DataModule): max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: From ebe97a07b082f9bee4a18b5f8e54c453187a74bb Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 23 Jan 2024 16:26:24 +0800 Subject: [PATCH 082/216] Reworked README.md (#1470) * Rework README.md Co-authored-by: Fangjun Kuang --------- Co-authored-by: Fangjun Kuang --- README.md | 439 +++++++++++++++++++++++++----------------------------- 1 file changed, 201 insertions(+), 238 deletions(-) diff --git a/README.md b/README.md index 15e9e17e6..61920be65 100644 --- a/README.md +++ b/README.md @@ -2,46 +2,83 @@ -## Introduction +# Introduction -icefall contains ASR recipes for various datasets -using . +The icefall peoject contains speech related recipes for various datasets +using [k2-fsa](https://github.com/k2-fsa/k2) and [lhotse](https://github.com/lhotse-speech/lhotse). -You can use to deploy models -trained with icefall. +You can use [sherpa](https://github.com/k2-fsa/sherpa), [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn) or [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx) for deployment with models +in icefall; these frameworks also support models not included in icefall; please refer to respective documents for more details. You can try pre-trained models from within your browser without the need -to download or install anything by visiting -See for more details. +to download or install anything by visiting this [huggingface space](https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition). +Please refer to [document](https://k2-fsa.github.io/icefall/huggingface/spaces.html) for more details. -## Installation +# Installation -Please refer to +Please refer to [document](https://icefall.readthedocs.io/en/latest/installation/index.html) for installation. -## Recipes +# Recipes -Please refer to -for more information. +Please refer to [document](https://icefall.readthedocs.io/en/latest/recipes/index.html) +for more details. 
-We provide the following recipes: +## ASR: Automatic Speech Recognition +### Supported Datasets - [yesno][yesno] - - [LibriSpeech][librispeech] - - [GigaSpeech][gigaspeech] - - [AMI][ami] + + - [Aidatatang_200zh][aidatatang_200zh] - [Aishell][aishell] - [Aishell2][aishell2] - [Aishell4][aishell4] + - [Alimeeting][alimeeting] + - [AMI][ami] + - [CommonVoice][commonvoice] + - [Corpus of Spontaneous Japanese][csj] + - [GigaSpeech][gigaspeech] + - [LibriCSS][libricss] + - [LibriSpeech][librispeech] + - [Libriheavy][libriheavy] + - [Multi-Dialect Broadcast News Arabic Speech Recognition][mgb2] + - [PeopleSpeech][peoplespeech] + - [SPGISpeech][spgispeech] + - [Switchboard][swbd] - [TIMIT][timit] - [TED-LIUM3][tedlium3] - - [Aidatatang_200zh][aidatatang_200zh] - - [WenetSpeech][wenetspeech] - - [Alimeeting][alimeeting] - - [Switchboard][swbd] - [TAL_CSASR][tal_csasr] + - [Voxpopuli][voxpopuli] + - [XBMU-AMDO31][xbmu-amdo31] + - [WenetSpeech][wenetspeech] + +More datasets will be added in the future. -### yesno +### Supported Models + +The [LibriSpeech][librispeech] recipe supports the most comprehensive set of models, you are welcome to try them out. + +#### CTC + - TDNN LSTM CTC + - Conformer CTC + - Zipformer CTC + +#### MMI + - Conformer MMI + - Zipformer MMI + +#### Transducer + - Conformer-based Encoder + - LSTM-based Encoder + - Zipformer-based Encoder + - LSTM-based Predictor + - [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/) + +If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details. + +We would like to highlight the performance of some of the recipes here. + +### [yesno][yesno] This is the simplest ASR recipe in `icefall` and can be run on CPU. Training takes less than 30 seconds and gives you the following WER: @@ -52,350 +89,264 @@ Training takes less than 30 seconds and gives you the following WER: We provide a Colab notebook for this recipe: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing) -### LibriSpeech +### [LibriSpeech][librispeech] -Please see +Please see [RESULTS.md](https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md) for the **latest** results. 
-We provide 5 models for this recipe: - -- [conformer CTC model][LibriSpeech_conformer_ctc] -- [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc] -- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer] -- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless] -- [Transducer: Zipformer encoder + Embedding decoder][LibriSpeech_zipformer] - -#### Conformer CTC Model - -The best WER we currently have is: +#### [Conformer CTC](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc) | | test-clean | test-other | |-----|------------|------------| | WER | 2.42 | 5.73 | -We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing) -#### TDNN LSTM CTC Model - -The WER for this model is: +#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/tdnn_lstm_ctc) | | test-clean | test-other | |-----|------------|------------| | WER | 6.59 | 17.69 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing) -#### Transducer: Conformer encoder + LSTM decoder +#### [Transducer (Conformer Encoder + LSTM Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/transducer) -Using Conformer as encoder and LSTM as decoder. +| | test-clean | test-other | +|---------------|------------|------------| +| greedy search | 3.07 | 7.51 | -The best WER with greedy search is: +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) -| | test-clean | test-other | -|-----|------------|------------| -| WER | 3.07 | 7.51 | +#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/transducer) -We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) - -#### Transducer: Conformer encoder + Embedding decoder - -Using Conformer as encoder. The decoder consists of 1 embedding layer -and 1 convolutional layer. - -The best WER using modified beam search with beam size 4 is: - -| | test-clean | test-other | -|-----|------------|------------| -| WER | 2.56 | 6.27 | - -Note: No auxiliary losses are used in the training and no LMs are used -in the decoding. 
- -We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) +| | test-clean | test-other | +|---------------------------------------|------------|------------| +| modified_beam_search (`beam_size=4`) | 2.56 | 6.27 | -#### k2 pruned RNN-T +We provide a Colab notebook to run test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) + + +#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/zipformer) + +WER (modified_beam_search `beam_size=4` unless further stated) + +1. LibriSpeech-960hr | Encoder | Params | test-clean | test-other | epochs | devices | |-----------------|--------|------------|------------|---------|------------| -| zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 | -| zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 | -| zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 | -| zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 | +| Zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 | +| Zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 | +| Zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 | +| Zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 | -Note: No auxiliary losses are used in the training and no LMs are used -in the decoding. +2. LibriSpeech-960hr + GigaSpeech -#### k2 pruned RNN-T + GigaSpeech - -| | test-clean | test-other | -|-----|------------|------------| -| WER | 1.78 | 4.08 | - -Note: No auxiliary losses are used in the training and no LMs are used -in the decoding. - -#### k2 pruned RNN-T + GigaSpeech + CommonVoice - -| | test-clean | test-other | -|-----|------------|------------| -| WER | 1.90 | 3.98 | - -Note: No auxiliary losses are used in the training and no LMs are used -in the decoding. +| Encoder | Params | test-clean | test-other | +|-----------------|--------|------------|------------| +| Zipformer | 65.5M | 1.78 | 4.08 | -### GigaSpeech +3. LibriSpeech-960hr + GigaSpeech + CommonVoice -We provide three models for this recipe: +| Encoder | Params | test-clean | test-other | +|-----------------|--------|------------|------------| +| Zipformer | 65.5M | 1.90 | 3.98 | -- [Conformer CTC model][GigaSpeech_conformer_ctc] -- [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2]. 
-- [Transducer: Zipformer encoder + Embedding decoder][GigaSpeech_zipformer] -#### Conformer CTC +### [GigaSpeech][gigaspeech] + +#### [Conformer CTC](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/conformer_ctc) | | Dev | Test | |-----|-------|-------| | WER | 10.47 | 10.58 | -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss +#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/pruned_transducer_stateless2) + +Conformer Encoder + Stateless Predictor + k2 Pruned RNN-T Loss | | Dev | Test | |----------------------|-------|-------| -| greedy search | 10.51 | 10.73 | -| fast beam search | 10.50 | 10.69 | -| modified beam search | 10.40 | 10.51 | +| greedy_search | 10.51 | 10.73 | +| fast_beam_search | 10.50 | 10.69 | +| modified_beam_search | 10.40 | 10.51 | -#### Transducer: Zipformer encoder + Embedding decoder +#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/zipformer) | | Dev | Test | |----------------------|-------|-------| -| greedy search | 10.31 | 10.50 | -| fast beam search | 10.26 | 10.48 | -| modified beam search | 10.25 | 10.38 | +| greedy_search | 10.31 | 10.50 | +| fast_beam_search | 10.26 | 10.48 | +| modified_beam_search | 10.25 | 10.38 | -### Aishell +### [Aishell][aishell] -We provide three models for this recipe: [conformer CTC model][Aishell_conformer_ctc], -[TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc], and [Transducer Stateless Model][Aishell_pruned_transducer_stateless7], - -#### Conformer CTC Model - -The best CER we currently have is: - -| | test | -|-----|------| -| CER | 4.26 | - -#### TDNN LSTM CTC Model - -The CER for this model is: +#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/tdnn_lstm_ctc) | | test | |-----|-------| | CER | 10.16 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing) -#### Transducer Stateless Model - -The best CER we currently have is: +#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/transducer_stateless) | | test | |-----|------| | CER | 4.38 | -We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing) + +#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/zipformer) + +WER (modified_beam_search `beam_size=4`) + +| Encoder | Params | dev | test | epochs | +|-----------------|--------|-----|------|---------| +| Zipformer | 73.4M | 4.13| 4.40 | 55 | +| Zipformer-small | 30.2M | 4.40| 4.67 | 55 | +| Zipformer-large | 157.3M | 4.03| 4.28 | 56 | -### Aishell2 +### 
[Aishell4][aishell4] -We provide one model for this recipe: [Transducer Stateless Model][Aishell2_pruned_transducer_stateless5]. - -#### Transducer Stateless Model - -The best WER we currently have is: - -| | dev-ios | test-ios | -|-----|------------|------------| -| WER | 5.32 | 5.56 | - - -### Aishell4 - -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5]. - -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets) - -The best CER we currently have is: +#### [Transducer (pruned_transducer_stateless5)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell4/ASR/pruned_transducer_stateless5) +1 Trained with all subsets: | | test | |-----|------------| | CER | 29.08 | - -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) -### TIMIT +### [TIMIT][timit] -We provide two models for this recipe: [TDNN LSTM CTC model][TIMIT_tdnn_lstm_ctc] -and [TDNN LiGRU CTC model][TIMIT_tdnn_ligru_ctc]. +#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_lstm_ctc) -#### TDNN LSTM CTC Model - -The best PER we currently have is: - -||TEST| -|--|--| +| |TEST| +|---|----| |PER| 19.71% | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Hs9DA4V96uapw_30uNp32OMJgkuR5VVd?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Hs9DA4V96uapw_30uNp32OMJgkuR5VVd?usp=sharing) -#### TDNN LiGRU CTC Model +#### [TDNN LiGRU CTC](https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_ligru_ctc) -The PER for this model is: - -||TEST| -|--|--| +| |TEST| +|---|----| |PER| 17.66% | -We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) -### TED-LIUM3 +### [TED-LIUM3][tedlium3] -We provide two models for this recipe: [Transducer Stateless: Conformer encoder + Embedding decoder][TED-LIUM3_transducer_stateless] and [Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TED-LIUM3_pruned_transducer_stateless]. 
+#### [Transducer (Conformer Encoder + Embedding Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless) -#### Transducer Stateless: Conformer encoder + Embedding decoder - -The best WER using modified beam search with beam size 4 is: - -| | dev | test | -|-----|-------|--------| -| WER | 6.91 | 6.33 | - -Note: No auxiliary losses are used in the training and no LMs are used in the decoding. - -We provide a Colab notebook to run a pre-trained Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing) - -#### Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss - -The best WER using modified beam search with beam size 4 is: - -| | dev | test | -|-----|-------|--------| -| WER | 6.77 | 6.14 | - -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing) +| | dev | test | +|--------------------------------------|-------|--------| +| modified_beam_search (`beam_size=4`) | 6.91 | 6.33 | -### Aidatatang_200zh +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing) -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aidatatang_200zh_pruned_transducer_stateless2]. +#### [Transducer (pruned_transducer_stateless)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/pruned_transducer_stateless) -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss +| | dev | test | +|--------------------------------------|-------|--------| +| modified_beam_search (`beam_size=4`) | 6.77 | 6.14 | + +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing) + + +### [Aidatatang_200zh][aidatatang_200zh] + +#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2) | | Dev | Test | |----------------------|-------|-------| -| greedy search | 5.53 | 6.59 | -| fast beam search | 5.30 | 6.34 | -| modified beam search | 5.27 | 6.33 | +| greedy_search | 5.53 | 6.59 | +| fast_beam_search | 5.30 | 6.34 | +| modified_beam_search | 5.27 | 6.33 | -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing) -### WenetSpeech +### [WenetSpeech][wenetspeech] -We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 
pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5]. - -#### Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset, offline ASR) +#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless2) | | Dev | Test-Net | Test-Meeting | |----------------------|-------|----------|--------------| -| greedy search | 7.80 | 8.75 | 13.49 | -| modified beam search| 7.76 | 8.71 | 13.41 | -| fast beam search | 7.94 | 8.74 | 13.80 | +| greedy_search | 7.80 | 8.75 | 13.49 | +| fast_beam_search | 7.94 | 8.74 | 13.80 | +| modified_beam_search | 7.76 | 8.71 | 13.41 | + +We provide a Colab notebook to run the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing) + +#### [Transducer **Streaming** (pruned_transducer_stateless5) ](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless5) -#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset) -**Streaming**: | | Dev | Test-Net | Test-Meeting | |----------------------|-------|----------|--------------| | greedy_search | 8.78 | 10.12 | 16.16 | -| modified_beam_search | 8.53| 9.95 | 15.81 | | fast_beam_search| 9.01 | 10.47 | 16.28 | +| modified_beam_search | 8.53| 9.95 | 15.81 | -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless2 model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing) -### Alimeeting +### [Alimeeting][alimeeting] -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Alimeeting_pruned_transducer_stateless2]. - -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with far subset) +#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/alimeeting/ASR/pruned_transducer_stateless2) | | Eval | Test-Net | |----------------------|--------|----------| -| greedy search | 31.77 | 34.66 | -| fast beam search | 31.39 | 33.02 | -| modified beam search | 30.38 | 34.25 | +| greedy_search | 31.77 | 34.66 | +| fast_beam_search | 31.39 | 33.02 | +| modified_beam_search | 30.38 | 34.25 | -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing) -### TAL_CSASR +### [TAL_CSASR][tal_csasr] -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5]. 
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss +#### [Transducer (pruned_transducer_stateless5)](https://github.com/k2-fsa/icefall/tree/master/egs/tal_csasr/ASR/pruned_transducer_stateless5) The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English): |decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en | |--|--|--|--|--|--|--| |greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13| -|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 | |fast_beam_search| 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77| +|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 | -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing) -## Deployment with C++ +## TTS: Text-to-Speech -Once you have trained a model in icefall, you may want to deploy it with C++, -without Python dependencies. +### Supported Datasets -Please refer to the documentation - + - [LJSpeech][ljspeech] + - [VCTK][vctk] + +### Supported Models + + - [VITS](https://arxiv.org/abs/2106.06103) + +# Deployment with C++ + +Once you have trained a model in icefall, you may want to deploy it with C++ without Python dependencies. + +Please refer to the [document](https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/librispeech/conformer_ctc.html#deployment-with-c) for how to do this. We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++. 
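Before moving to C++, an exported `cpu_jit.pt` can also be sanity-checked from Python. The sketch below is hypothetical: it assumes a transducer export whose scripted module exposes an `encoder` taking `(features, feature_lengths)`, as in the recipes' `jit_pretrained.py` scripts, and the file path is a placeholder:

```python
# Hypothetical smoke test for a torch.jit.script-exported transducer model.
import torch

model = torch.jit.load("path/to/exp/cpu_jit.pt", map_location="cpu")  # placeholder path
model.eval()

# Dummy 80-dim fbank features for one ~3 second utterance.
features = torch.randn(1, 300, 80)
feature_lengths = torch.tensor([300])

with torch.no_grad():
    encoder_out, encoder_out_lens = model.encoder(features, feature_lengths)

print(encoder_out.shape, encoder_out_lens)
```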
Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing) -[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc -[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc -[LibriSpeech_transducer]: egs/librispeech/ASR/transducer -[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless -[LibriSpeech_zipformer]: egs/librispeech/ASR/zipformer -[Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc -[Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc -[Aishell_pruned_transducer_stateless7]: egs/aishell/ASR/pruned_transducer_stateless7_bbpe -[Aishell2_pruned_transducer_stateless5]: egs/aishell2/ASR/pruned_transducer_stateless5 -[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5 -[TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc -[TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc -[TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless -[TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless -[GigaSpeech_conformer_ctc]: egs/gigaspeech/ASR/conformer_ctc -[GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2 -[GigaSpeech_zipformer]: egs/gigaspeech/ASR/zipformer -[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2 -[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2 -[WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5 -[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2 -[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5 [yesno]: egs/yesno/ASR [librispeech]: egs/librispeech/ASR [aishell]: egs/aishell/ASR @@ -411,3 +362,15 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [ami]: egs/ami [swbd]: egs/swbd/ASR [k2]: https://github.com/k2-fsa/k2 +[commonvoice]: egs/commonvoice/ASR +[csj]: egs/csj/ASR +[libricss]: egs/libricss/SURT +[libriheavy]: egs/libriheavy/ASR +[mgb2]: egs/mgb2/ASR +[peoplespeech]: egs/peoplespeech/ASR +[spgispeech]: egs/spgispeech/ASR +[voxpopuli]: egs/voxpopuli/ASR +[xbmu-amdo31]: egs/xbmu-amdo31/ASR + +[vctk]: egs/vctk/TTS +[ljspeech]: egs/ljspeech/TTS \ No newline at end of file From 559ed150bb73e3e2e89c703bac4c37744e516e8e Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 23 Jan 2024 22:51:09 +0800 Subject: [PATCH 083/216] Fix typo (#1471) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 61920be65..f92c85ad4 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # Introduction -The icefall peoject contains speech related recipes for various datasets +The icefall project contains speech-related recipes for various datasets using [k2-fsa](https://github.com/k2-fsa/k2) and [lhotse](https://github.com/lhotse-speech/lhotse). 
You can use [sherpa](https://github.com/k2-fsa/sherpa), [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn) or [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx) for deployment with models @@ -373,4 +373,4 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [xbmu-amdo31]: egs/xbmu-amdo31/ASR [vctk]: egs/vctk/TTS -[ljspeech]: egs/ljspeech/TTS \ No newline at end of file +[ljspeech]: egs/ljspeech/TTS From 9c494a3329d531e4ed10117ec0b6f244d0a61ce3 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 25 Jan 2024 18:41:43 +0800 Subject: [PATCH 084/216] typos fixed (#1472) --- README.md | 12 ++++++------ .../ASR/local/compute_fbank_peoples_speech_splits.py | 4 ++-- .../ASR/local/compute_fbank_wenetspeech_splits.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index f92c85ad4..cc817702b 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt | | test-clean | test-other | |---------------|------------|------------| -| greedy search | 3.07 | 7.51 | +| greedy_search | 3.07 | 7.51 | We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) @@ -127,7 +127,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt | modified_beam_search (`beam_size=4`) | 2.56 | 6.27 | -We provide a Colab notebook to run test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) #### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/zipformer) @@ -147,14 +147,14 @@ WER (modified_beam_search `beam_size=4` unless further stated) | Encoder | Params | test-clean | test-other | |-----------------|--------|------------|------------| -| Zipformer | 65.5M | 1.78 | 4.08 | +| Zipformer | 65.5M | 1.78 | 4.08 | 3. 
LibriSpeech-960hr + GigaSpeech + CommonVoice | Encoder | Params | test-clean | test-other | |-----------------|--------|------------|------------| -| Zipformer | 65.5M | 1.90 | 3.98 | +| Zipformer | 65.5M | 1.90 | 3.98 | ### [GigaSpeech][gigaspeech] @@ -246,7 +246,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt ### [TED-LIUM3][tedlium3] -#### [Transducer (Conformer Encoder + Embedding Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless) +#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless) | | dev | test | |--------------------------------------|-------|--------| @@ -287,7 +287,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt | fast_beam_search | 7.94 | 8.74 | 13.80 | | modified_beam_search | 7.76 | 8.71 | 13.41 | -We provide a Colab notebook to run the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing) #### [Transducer **Streaming** (pruned_transducer_stateless5) ](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless5) diff --git a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py index c2ab3d07d..6f05b9f8c 100755 --- a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py +++ b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py @@ -67,14 +67,14 @@ def get_args(): "--start", type=int, default=0, - help="Process pieces starting from this number (inclusive).", + help="Process pieces starting from this number (included).", ) parser.add_argument( "--stop", type=int, default=-1, - help="Stop processing pieces until this number (exclusive).", + help="Stop processing pieces until this number (excluded).", ) return parser.parse_args() diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py index 99d39bbdc..a87801462 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py @@ -78,14 +78,14 @@ def get_parser(): "--start", type=int, default=0, - help="Process pieces starting from this number (inclusive).", + help="Process pieces starting from this number (included).", ) parser.add_argument( "--stop", type=int, default=-1, - help="Stop processing pieces until this number (exclusive).", + help="Stop processing pieces until this number (excluded).", ) return parser From c401a2646b347bf1fff0c2ce1a4ee13b0f482448 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 26 Jan 2024 15:50:11 +0800 Subject: [PATCH 085/216] minor fix of zipformer/optim.py (#1474) --- egs/librispeech/ASR/zipformer/optim.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 714d8db9a..aaffbfed5 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -298,11 +298,14 @@ class 
ScaledAdam(BatchedOptimizer): # case 2 or case 4 # the input is groups of parameter or named parameter. for cur_group in iterable_or_groups: - assert "named_params" in cur_group - name_list = [x[0] for x in cur_group["named_params"]] - p_list = [x[1] for x in cur_group["named_params"]] - del cur_group["named_params"] - cur_group["params"] = p_list + if "named_params" in cur_group: + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] + del cur_group["named_params"] + cur_group["params"] = p_list + else: + assert "params" in cur_group + name_list = ["foo" for _ in cur_group["params"]] param_groups.append(cur_group) param_groups_names.append(name_list) From 8d39f9508bd5b27627165696758cad2e96dca20b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 26 Jan 2024 19:18:33 +0800 Subject: [PATCH 086/216] Fix torchscript export to use tokens.txt instead of lang_dir (#1475) --- .../pruned_transducer_stateless2/export.py | 25 +++++++++------ .../ASR/pruned_transducer_stateless2/lstmp.py | 1 + .../scaling_converter.py | 1 + .../pruned_transducer_stateless2/export.py | 24 ++++++++------ .../ASR/pruned_transducer_stateless2/lstmp.py | 1 + .../scaling_converter.py | 1 + .../ASR/lstm_transducer_stateless2/export.py | 7 ++--- .../pruned_stateless_emformer_rnnt2/export.py | 1 + .../pruned_transducer_stateless/decoder.py | 2 +- .../export.py | 8 ++--- .../pruned_transducer_stateless5/export.py | 31 +++++++++---------- .../ASR/pruned_transducer_stateless5/lstmp.py | 1 + .../scaling_converter.py | 1 + .../pruned_transducer_stateless2/export.py | 23 +++++++------- .../pruned_transducer_stateless5/export.py | 25 +++++++-------- 15 files changed, 83 insertions(+), 69 deletions(-) mode change 100644 => 100755 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py create mode 120000 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py create mode 120000 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py create mode 120000 egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py create mode 120000 egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py create mode 120000 egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py create mode 120000 egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py old mode 100644 new mode 100755 index e348f7b2b..5179bfa1c --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # # See ../../../../LICENSE for clarification regarding multiple authors @@ -20,7 +21,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 29 \ --avg 19 @@ -45,12 +46,13 @@ import argparse import logging from pathlib import Path +import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -85,10 +87,10 @@ def get_parser(): ) parser.add_argument( - 
"--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -122,10 +124,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -152,6 +158,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index b6190e8a6..4a44f7bcb 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,13 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +99,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,12 +136,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -183,6 +186,7 @@ def main(): model.eval() if params.jit: + 
convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 5712da25e..aeed58dec 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -218,10 +218,9 @@ def export_decoder_model_jit_trace( decoder_filename: The filename to save the exported model. """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + # TODO(fangjun): Change the function name since we are actually using + # torch.jit.script instead of torch.jit.trace + traced_model = torch.jit.script(decoder_model) traced_model.save(decoder_filename) logging.info(f"Saved to {decoder_filename}") diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index ec2c9d580..e42a5c6ef 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -159,6 +159,7 @@ def main(): # Load id of the token and the vocab size params.blank_id = token_table[""] + params.unk_id = token_table[""] params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 03847b449..b961611f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -91,7 +91,7 @@ class Decoder(nn.Module): Returns: Return a tensor of shape (N, U, embedding_dim). 
""" - embedding_out = self.embedding(y) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 59a7eb589..67041012d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -87,7 +87,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e ln -s pretrained.pt epoch-999.pt ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --use-averaged-model False \ --epoch 999 \ --avg 1 \ @@ -113,7 +113,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e ln -s pretrained.pt epoch-999.pt ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --use-averaged-model False \ --epoch 999 \ --avg 1 \ diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py index bc33dd160..0f6190a41 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py @@ -23,7 +23,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ + --tokens ./data/lang_char/tokens.txt \ --epoch 30 \ --avg 24 \ --use-averaged-model True @@ -50,8 +50,9 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -60,8 +61,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -118,13 +118,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -160,13 +157,14 @@ def main(): logging.info(f"device: {device}") - bpe_model = params.lang_dir + "/bpe.model" - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - lexicon = 
Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -256,6 +254,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py index 5d25daf5e..2f6ef488e 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py @@ -24,7 +24,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 10 \ --avg 2 \ --jit 1 @@ -47,7 +47,7 @@ for how to use them. ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 10 \ --avg 2 \ --jit-trace 1 @@ -63,7 +63,7 @@ Check ./jit_pretrained.py for usage. 
./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 10 \ --avg 2 @@ -91,14 +91,14 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -133,10 +133,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -313,10 +313,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py index cb541070e..5ff1f4a3b 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py @@ -20,7 +20,7 @@ Usage for offline: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 4 \ --avg 1 @@ -28,7 +28,7 @@ It will generate a file exp_dir/pretrained.pt for offline ASR. ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 4 \ --avg 1 \ --jit True @@ -38,7 +38,7 @@ It will generate a file exp_dir/cpu_jit.pt for offline ASR. Usage for streaming: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 7 \ --avg 1 @@ -46,7 +46,7 @@ It will generate a file exp_dir/pretrained.pt for streaming ASR. 
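The change these export scripts share is small enough to summarize; a minimal sketch, assuming icefall's `num_tokens` helper and the conventional `<blk>`/`<unk>` symbols in `tokens.txt`:

```python
# Sketch of the token-table handling that replaces --lang-dir / --bpe-model.
import k2
from icefall.utils import num_tokens

token_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")
blank_id = token_table["<blk>"]
unk_id = token_table["<unk>"]
vocab_size = num_tokens(token_table) + 1  # +1 for <blk>
```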
./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 7 \ --avg 1 \ --jit True @@ -73,13 +73,13 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -114,10 +114,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -152,10 +152,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) From 1c30847947f9d8b1416ef3e70408c07eab807f3d Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Sat, 27 Jan 2024 00:32:30 +0800 Subject: [PATCH 087/216] Whisper Fine-tuning Recipe on Aishell1 (#1466) * add decode seamlessm4t * add requirements * add decoding with avg model * add token files * add custom tokenizer * support deepspeed to finetune large model * support large-v3 * add model saving * using monkey patch to replace models * add manifest dir option --- egs/aishell/ASR/README.md | 7 + egs/aishell/ASR/RESULTS.md | 67 +- .../ASR/local/compute_fbank_aishell.py | 45 +- egs/aishell/ASR/prepare.sh | 13 + egs/aishell/ASR/whisper/asr_datamodule.py | 1 + egs/aishell/ASR/whisper/decode.py | 503 ++++++++++ egs/aishell/ASR/whisper/ds_config_zero1.json | 38 + egs/aishell/ASR/whisper/label_smoothing.py | 1 + egs/aishell/ASR/whisper/optim.py | 1 + egs/aishell/ASR/whisper/requirements.txt | 10 + egs/aishell/ASR/whisper/train.py | 927 ++++++++++++++++++ .../whisper_encoder_forward_monkey_patch.py | 29 + .../ASR/local/compute_fbank_musan.py | 59 +- icefall/dist.py | 2 +- 14 files changed, 1682 insertions(+), 21 deletions(-) create mode 120000 egs/aishell/ASR/whisper/asr_datamodule.py create mode 100755 egs/aishell/ASR/whisper/decode.py create mode 100644 egs/aishell/ASR/whisper/ds_config_zero1.json create mode 120000 egs/aishell/ASR/whisper/label_smoothing.py create mode 120000 egs/aishell/ASR/whisper/optim.py create mode 100755 egs/aishell/ASR/whisper/requirements.txt create mode 100755 egs/aishell/ASR/whisper/train.py create mode 100644 egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index 176f065e5..b54719162 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -24,3 +24,10 @@ The following table lists the differences among them. The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer. 
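A minimal sketch of the stateless predictor described above: an embedding followed by a causal Conv1d over the last `context_size` labels. It is written in the spirit of the recipes' `decoder.py`, including the padding-aware embedding lookup from the `decoder.py` change earlier in this series; exact details such as the convolution's group count vary across recipes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoder(nn.Module):
    """Embedding + Conv1d predictor; it only sees the last few labels."""

    def __init__(self, vocab_size: int, embedding_dim: int, blank_id: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=blank_id)
        self.context_size = context_size
        self.conv = nn.Conv1d(
            embedding_dim,
            embedding_dim,
            kernel_size=context_size,
            groups=embedding_dim,  # depthwise here; some recipes use fewer groups
            bias=False,
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) label history; entries < 0 mark padding and contribute zeros.
        embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
        embedding_out = embedding_out.permute(0, 2, 1)  # (N, C, U)
        embedding_out = F.pad(embedding_out, (self.context_size - 1, 0))  # causal left pad
        embedding_out = self.conv(embedding_out).permute(0, 2, 1)  # (N, U, C)
        return F.relu(embedding_out)
```

With no recurrent state, the predictor's output depends only on the last `context_size` labels, which is what the referenced paper means by stateless prediction.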
+ +# Whisper + +Recipe to finetune large pretrained models +| | Encoder | Decoder | Comment | +|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------| +| `whisper` | Transformer | Transformer | support fine-tuning using deepspeed diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index ff9504274..46d712fb2 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,5 +1,63 @@ ## Results +### Aishell training results (Fine-tuning Pretrained Models) +#### Whisper +[./whisper](./whisper) +##### fine-tuning results on Aishell test set on whisper medium, large-v2, large-v3 + +| | test (before fine-tuning) | test (after fine-tuning) | comment | +|------------------------|------|------|-----------------------------------------| +| medium | 7.23 | 3.27 | --epoch 10 --avg 4, ddp | +| large-v2 | 6.56 | 2.47 | --epoch 10 --avg 6, deepspeed zero stage1 | +| large-v3 | 6.06 | 2.84 | --epoch 5 --avg 3, deepspeed zero stage1 | + +Command for training is: +```bash +pip install -r whisper/requirements.txt + +./prepare.sh --stage 30 --stop_stage 30 + +#fine-tuning with deepspeed zero stage 1 +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --deepspeed \ + --deepspeed_config ./whisper/ds_config_zero1.json + +# fine-tuning with ddp +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_medium \ + --base-lr 1e-5 \ + --model-name medium +``` + +Command for decoding using fine-tuned models: +```bash +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --beam-size 10 --max-duration 50 +``` +Command for decoding using pretrained models (before fine-tuning): +```bash +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 10 --max-duration 50 +``` +Fine-tuned models, training logs, decoding logs, tensorboard and decoding results +are available at + + ### Aishell training result (Stateless Transducer) #### Zipformer (Byte-level BPE) @@ -71,7 +129,7 @@ It's reworked Zipformer with Pruned RNNT loss. 
Command for training is: ```bash -./prepare.sh +./prepare.sh export CUDA_VISIBLE_DEVICES="0,1" @@ -136,7 +194,7 @@ export CUDA_VISIBLE_DEVICES="0,1" --feedforward-dim 512,768,768,768,768,768 \ --encoder-dim 192,256,256,256,256,256 \ --encoder-unmasked-dim 192,192,192,192,192,192 \ - --max-duration 1200 + --max-duration 1200 ``` Command for decoding is: @@ -186,7 +244,7 @@ export CUDA_VISIBLE_DEVICES="0,1" --feedforward-dim 512,768,1536,2048,1536,768 \ --encoder-dim 192,256,512,768,512,256 \ --encoder-unmasked-dim 192,192,256,320,256,192 \ - --max-duration 800 + --max-duration 800 ``` Command for decoding is: @@ -202,7 +260,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do --num-encoder-layers 2,2,4,5,4,2 \ --feedforward-dim 512,768,1536,2048,1536,768 \ --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 + --encoder-unmasked-dim 192,192,256,320,256,192 done ``` @@ -755,7 +813,6 @@ python3 ./transducer_stateless/decode.py \ --max-sym-per-frame 3 ``` -### Aishell training results (Transducer-stateless) #### 2022-02-18 (Pingfeng Luo) : The tensorboard log for training is available at And pretrained model is available at diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index c7000da1c..3c48f0aa1 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -29,7 +29,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,9 +49,14 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False): +def compute_fbank_aishell( + num_mel_bins: int = 80, + perturb_speed: bool = False, + whisper_fbank: bool = False, + output_dir: str = "data/fbank", +): src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + output_dir = Path(output_dir) num_jobs = min(15, os.cpu_count()) dataset_parts = ( @@ -68,8 +80,12 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False): list(manifests.keys()), dataset_parts, ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -82,7 +98,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False): supervisions=m["supervisions"], ) if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -111,6 +127,18 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. 
Default: data/fbank.", + ) return parser.parse_args() @@ -121,5 +149,8 @@ if __name__ == "__main__": args = get_args() compute_fbank_aishell( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, ) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 9f73a2073..b7be89bc8 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -376,3 +376,16 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then --vocab-size 4336 \ --master-port 12345 fi + +# whisper large-v3 using 128 mel bins, others using 80 mel bins +whisper_mel_bins=80 +output_dir=data/fbank_whisper +if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then + log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning" + if [ ! -f $output_dir/.aishell.whisper.done ]; then + mkdir -p $output_dir + ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir + ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir + touch $output_dir/.aishell.whisper.done + fi +fi diff --git a/egs/aishell/ASR/whisper/asr_datamodule.py b/egs/aishell/ASR/whisper/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/aishell/ASR/whisper/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py new file mode 100755 index 000000000..7f841dcb7 --- /dev/null +++ b/egs/aishell/ASR/whisper/decode.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Wei Kang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
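
The `--whisper-fbank` switch added above in `local/compute_fbank_aishell.py` (and driven by stage 30 of `prepare.sh`) only changes which lhotse extractor gets constructed; everything else in the feature pipeline stays the same. A minimal sketch of that selection, using only the classes the patch itself imports — the 80/128 mel-bin convention is taken from the comment in `prepare.sh` and should be treated as an assumption for other setups:

```python
from lhotse import Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig


def make_extractor(num_mel_bins: int = 80, whisper_fbank: bool = False):
    # WhisperFbank reproduces Whisper's own log-mel front end;
    # plain Fbank is what the other icefall recipes use by default.
    if whisper_fbank:
        return WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
        )
    return Fbank(FbankConfig(num_mel_bins=num_mel_bins))


# large-v3 expects 128 mel bins; medium and large-v2 use 80 (see prepare.sh).
extractor = make_extractor(num_mel_bins=80, whisper_fbank=True)
```
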
+""" +Usage: +# Command for decoding using fine-tuned models: +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --manifest-dir data/fbank_whisper \ + --beam-size 10 --max-duration 50 + +# Command for decoding using pretrained models (before fine-tuning): + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --manifest-dir data/fbank_whisper \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 10 --max-duration 50 + +""" + +import argparse +import logging +import re +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +import whisper +from asr_datamodule import AishellAsrDataModule +from tn.chinese.normalizer import Normalizer +from whisper.normalizers import BasicTextNormalizer +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward +from zhconv import convert + +from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def average_checkpoints( + filenames: List[Path], device: torch.device = torch.device("cpu") +) -> dict: + """Average a list of checkpoints. + The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. + + Args: + filenames: + Filenames of the checkpoints to be averaged. We assume all + checkpoints are saved by :func:`save_checkpoint`. + device: + Move checkpoints to this device before averaging. + Returns: + Return a dict (i.e., state_dict) which is the average of all + model state dicts contained in the checkpoints. + """ + n = len(filenames) + + if "model" in torch.load(filenames[0], map_location=device): + avg = torch.load(filenames[0], map_location=device)["model"] + else: + avg = torch.load(filenames[0], map_location=device) + + # Identify shared parameters. Two parameters are said to be shared + # if they have the same data_ptr + uniqued: Dict[int, str] = dict() + + for k, v in avg.items(): + v_data_ptr = v.data_ptr() + if v_data_ptr in uniqued: + continue + uniqued[v_data_ptr] = k + + uniqued_names = list(uniqued.values()) + + for i in range(1, n): + if "model" in torch.load(filenames[i], map_location=device): + state_dict = torch.load(filenames[i], map_location=device)["model"] + else: + state_dict = torch.load(filenames[i], map_location=device) + for k in uniqued_names: + avg[k] += state_dict[k] + + for k in uniqued_names: + if avg[k].is_floating_point(): + avg[k] /= n + else: + avg[k] //= n + + return avg + + +def remove_punctuation(text: str or List[str]): + """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py + + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings without any punctuation. 
+ """ + punctuation = "!,.;:?、!,。;:?《》 " + if isinstance(text, str): + text = re.sub(r"[{}]+".format(punctuation), "", text).strip() + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = re.sub(r"[{}]+".format(punctuation), "", t).strip() + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type {type(text)}") + + +def to_simple(text: str or List[str]): + """Convert traditional Chinese to simplified Chinese. + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings converted to simplified Chinese. + """ + if isinstance(text, str): + text = convert(text, "zh-cn") + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = convert(t, "zh-cn") + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type{type(text)}") + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=-1, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="beam-search", + help="""Decoding method. + Supported values are: + - beam-search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=1, + help="beam size for beam search decoding", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. + """, + ) + + parser.add_argument( + "--remove-whisper-encoder-input-length-restriction", + type=str2bool, + default=True, + help="replace whisper encoder forward method to remove input length restriction", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: "beam-search" + - value: A list of lists. Each sublist is a list of token IDs. + Args: + params: + It is returned by :func:`get_params`. + model: + The neural model. + batch: + It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. + Returns: + Return a dict, whose key may be "beam-search". 
+ """ + dtype = torch.float16 + device = torch.device("cuda") + + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device, dtype=dtype).transpose(1, 2) + if not params.remove_whisper_encoder_input_length_restriction: + T = 3000 + if feature.shape[2] < T: + feature = torch.cat( + [ + feature, + torch.zeros( + feature.shape[0], feature.shape[1], T - feature.shape[2] + ).to(device, dtype=dtype), + ], + 2, + ) + + supervisions = batch["supervisions"] + feature_len = supervisions["num_frames"] + feature_len = feature_len.to(device, dtype=dtype) + results = model.decode(feature, params.decoding_options) + hyps = [result.text for result in results] + + hyps = remove_punctuation(hyps) + hyps = to_simple(hyps) + hyps = [params.normalizer.normalize(hyp) for hyp in hyps] + + return {"beam-search": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + The dataloader. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "beam-search". + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + batch=batch, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + # we compute CER for aishell dataset. 
+ results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + setup_logger( + f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + ) + + options = whisper.DecodingOptions( + task="transcribe", + language="zh", + without_timestamps=True, + beam_size=params.beam_size, + ) + params.decoding_options = options + params.cleaner = BasicTextNormalizer() + params.normalizer = Normalizer() + + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + + logging.info(f"device: {device}") + + if params.remove_whisper_encoder_input_length_restriction: + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + if params.epoch > 0: + if params.avg > 1: + start = params.epoch - params.avg + assert start >= 1, start + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + # deepspeed converted checkpoint only contains model state_dict + filenames = [ + f"{params.exp_dir}/epoch-{epoch}.pt" + for epoch in range(start, params.epoch + 1) + ] + model.load_state_dict(average_checkpoints(filenames)) + else: + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + # save checkpoints + filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save(model.state_dict(), filename) + else: + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + model.load_state_dict(checkpoint, strict=True) + else: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + aishell = AishellAsrDataModule(args) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) + test_dl = aishell.test_dataloaders(aishell.test_cuts()) + test_sets = ["valid", "test"] + test_dls = [valid_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/whisper/ds_config_zero1.json b/egs/aishell/ASR/whisper/ds_config_zero1.json new file mode 100644 index 000000000..bf8cc0452 --- /dev/null +++ b/egs/aishell/ASR/whisper/ds_config_zero1.json @@ -0,0 +1,38 @@ +{ + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 100, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 0.01 + }, + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-5 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-5, + "warmup_num_steps": 100 + } + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": 5, + "steps_per_print": 50, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": false +} diff --git a/egs/aishell/ASR/whisper/label_smoothing.py b/egs/aishell/ASR/whisper/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/aishell/ASR/whisper/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/optim.py b/egs/aishell/ASR/whisper/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/aishell/ASR/whisper/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt new file mode 100755 index 000000000..0708f2344 --- /dev/null +++ b/egs/aishell/ASR/whisper/requirements.txt @@ -0,0 +1,10 @@ +k2 +kaldialign +git+https://github.com/lhotse-speech/lhotse +sentencepiece +tensorboard +librosa +git+https://github.com/yuekaizhang/whisper.git +zhconv +WeTextProcessing +deepspeed diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py new file mode 100755 index 000000000..d16793eb2 --- /dev/null +++ b/egs/aishell/ASR/whisper/train.py @@ -0,0 +1,927 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
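
The `ds_config_zero1.json` file above is what `--deepspeed_config` points at in the fine-tuning command; `train.py` only has to hand it over to DeepSpeed. A condensed sketch of that path, restricted to the calls that appear further down in this file (`deepspeed.add_config_arguments`, `deepspeed.initialize`, and the engine's `backward`/`step`); the `nn.Linear` stand-in replaces the Whisper model that the real script loads:

```python
import argparse

import deepspeed
import torch.nn as nn

parser = argparse.ArgumentParser()
# Adds --deepspeed and --deepspeed_config, matching the torchrun command above.
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()

model = nn.Linear(10, 10)  # stand-in for the Whisper model loaded in train.py

# DeepSpeed builds the ZeRO stage-1 optimizer and WarmupLR scheduler from the
# JSON config, so on this code path they replace the script's AdamW/Eden objects.
model_engine, optimizer, _, scheduler = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters()
)

# Inside the training loop the engine owns the backward pass and the step:
#     model_engine.backward(loss)
#     model_engine.step()
```
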
+""" +Usage: + +#fine-tuning with deepspeed zero stage 1 +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --manifest-dir data/fbank_whisper \ + --deepspeed \ + --deepspeed_config ./whisper/ds_config_zero1.json + +# fine-tuning with ddp +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_medium \ + --manifest-dir data/fbank_whisper \ + --base-lr 1e-5 \ + --model-name medium +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import deepspeed +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import whisper +from asr_datamodule import AishellAsrDataModule +from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict +from label_smoothing import LabelSmoothingLoss +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.functional import pad as pad_tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import update_averaged_model +from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=1e-5, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser = deepspeed.add_config_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - frame_shift_ms: The frame shift in milliseconds. + - allowed_excess_duration_ratio: The allowed excess duration ratio. + - best_train_loss: The best training loss so far. + - best_valid_loss: The best validation loss so far. + - best_train_epoch: The epoch where the best training loss is achieved. + - best_valid_epoch: The epoch where the best validation loss is achieved. + - batch_idx_train: The batch index of the current batch. + - log_interval: Log training stats every `log_interval` batches. + - reset_interval: Reset the stats every `reset_interval` batches. + - valid_interval: Run validation every `valid_interval` batches. + - env_info: The environment information. 
+ """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "subsampling_factor": 2, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 5000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute the loss for the given batch. + Args: + params: + It is returned by :func:`get_params`. + tokenizer: + The tokenizer used to encode the text. + model: + The model for training. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + Whether it is training. + Returns: + Return a tuple of two elements. The first element is the loss tensor. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor: + padding_size = max(tensor.shape[0] for tensor in tensors) + dims = len(tensors[0].shape) + padded_tensors = [] + for tensor in tensors: + padding = [0] * 2 * dims + padding[-1] = padding_size - tensor.shape[0] + padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) + return torch.stack([tensor for tensor in padded_tensors], dim=0) + + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + + assert feature.ndim == 3 + feature = feature.to(device) + feature = feature.transpose(1, 2) # (N, C, T) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + # remove spaces in texts + texts = [text.replace(" ", "") for text in texts] + + text_tokens_list = [ + list(tokenizer.sot_sequence_including_notimestamps) + + tokenizer.encode(text) + + [tokenizer.eot] + for text in texts + ] + # convert it to torch tensor + text_tokens_list = [ + torch.LongTensor(text_tokens) for text_tokens in text_tokens_list + ] + + # 50256 is the index of for all whisper models + prev_outputs_tokens = _batch_tensors( + [tokens[:-1] for tokens in text_tokens_list], pad_value=50256 + ) + target_tokens = _batch_tensors( + [tokens[1:] for tokens in text_tokens_list], pad_value=50256 + ) + target_lengths = torch.LongTensor( + [tokens.shape[0] - 1 for 
tokens in text_tokens_list] + ) + + decoder_criterion = LabelSmoothingLoss( + ignore_index=50256, label_smoothing=0.1, reduction="sum" + ) + + # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|> + ignore_prefix_size = 3 + with torch.set_grad_enabled(is_training): + encoder_out = model.encoder(feature) + text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) + text_logits = text_logits[:, ignore_prefix_size:, :] + target_tokens = target_tokens[:, ignore_prefix_size:] + loss = decoder_criterion(text_logits, target_tokens.to(device)) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + tokenizer=tokenizer, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + if params.deepspeed: + # deepspeed's backward() is different from torch's backward() + # in that it does not accept a loss tensor as input. + # It computes the loss internally. + model.backward(loss) + model.step() + else: + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + and not params.deepspeed + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + if batch_idx % params.log_interval == 0: + try: + cur_lr = scheduler.get_last_lr()[0] + except: # noqa + cur_lr = 0.0 + cur_grad_scale = ( + scaler._scale.item() + if (params.use_fp16 and not params.deepspeed) + else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if (params.use_fp16 and not params.deepspeed) + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info(params) + + logging.info("About to create model") + + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + del model.alignment_heads + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + tokenizer = whisper.tokenizer.get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language="zh", + task="transcribe", + ) + + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + else: + device = torch.device("cpu") + logging.info(f"Device: {device}") + model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if world_size > 1: + if params.deepspeed: + logging.info("Using DeepSpeed") + model, optimizer, _, scheduler = deepspeed.initialize( + args=params, model=model, 
model_parameters=model.parameters() + ) + else: + logging.info("Using DDP") + setup_dist(use_ddp_launch=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders(aishell.train_cuts()) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + logging.info(f"start training from epoch {params.start_epoch}") + for epoch in range(params.start_epoch, params.num_epochs + 1): + if not params.deepspeed: + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + tokenizer=tokenizer, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.deepspeed: + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}", + client_state={}, + ) + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", + tag=f"epoch-{params.cur_epoch}", + ) + else: + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1 and not params.deepspeed: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = get_world_size() + rank = get_rank() + + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + run(rank=rank, world_size=world_size, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py new file mode 100644 index 000000000..5bfbdce3b --- /dev/null +++ b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py @@ -0,0 +1,29 @@ +import torch +import torch.nn.functional as F +import whisper + + +def forward(self, x: torch.Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +def replace_whisper_encoder_forward(): + """ + This function monkey patches the forward method of the whisper encoder. + To be called before the model is loaded, it changes whisper to process audio with any length < 30s. + """ + whisper.model.AudioEncoder.forward = forward diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 62036467e..d7781687f 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests. The generated fbank features are saved in data/fbank. """ - +import argparse import logging import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + MonoCut, + WhisperFbank, + WhisperFbankConfig, + combine, +) from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -45,11 +54,12 @@ def is_cut_long(c: MonoCut) -> bool: return c.duration > 5 -def compute_fbank_musan(): +def compute_fbank_musan( + num_mel_bins: int = 80, whisper_fbank: bool = False, output_dir: str = "data/fbank" +): src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + output_dir = Path(output_dir) num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 dataset_parts = ( "music", @@ -81,7 +91,12 @@ def compute_fbank_musan(): logging.info("Extracting features for Musan") - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. 
# create chunks of Musan with duration 5 - 10 seconds @@ -102,8 +117,36 @@ def compute_fbank_musan(): musan_cuts.to_file(musan_cuts_path) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. Default: data/fbank.", + ) + return parser.parse_args() + + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() + args = get_args() + compute_fbank_musan( + num_mel_bins=args.num_mel_bins, + whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, + ) diff --git a/icefall/dist.py b/icefall/dist.py index 922f31a2f..ee76e994a 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -22,7 +22,7 @@ from torch import distributed as dist def setup_dist( - rank, world_size, master_port=None, use_ddp_launch=False, master_addr=None + rank=None, world_size=None, master_port=None, use_ddp_launch=False, master_addr=None ): """ rank and world_size are used only if use_ddp_launch is False. From 37b975cac9ac237b71e957dd0770b091124fcb60 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 27 Jan 2024 06:41:56 +0800 Subject: [PATCH 088/216] fixed a CI test for `wenetspeech` (#1476) * Comply to issue #1149 https://github.com/k2-fsa/icefall/issues/1149 --- ...enetspeech-pruned-transducer-stateless2.sh | 6 ++--- .../pruned_transducer_stateless2/export.py | 19 +++++++------- .../pruned_transducer_stateless3/export.py | 19 +++++++------- .../export-onnx.py | 21 +++++++--------- .../ASR/transducer_stateless/export.py | 19 +++++++------- .../transducer_stateless_modified-2/export.py | 18 ++++++------- .../transducer_stateless_modified/export.py | 19 +++++++------- .../pruned_transducer_stateless5/export.py | 20 +++++++-------- .../pruned_transducer_stateless5/export.py | 19 ++++++-------- .../pruned_transducer_stateless2/export.py | 19 +++++++------- .../pruned_transducer_stateless7/export.py | 23 ++++++++--------- .../export-onnx-zh.py | 18 ++++++------- egs/swbd/ASR/conformer_ctc/export.py | 19 +++++++------- egs/tedlium3/ASR/conformer_ctc2/export.py | 18 ++++++------- .../export-onnx.py | 25 ++++++++----------- .../export-onnx-streaming.py | 19 +++++++------- .../export-onnx.py | 18 ++++++------- 17 files changed, 150 insertions(+), 169 deletions(-) diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh index a3a2d3080..981b74b76 100755 --- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -30,7 +30,7 @@ log "Test exporting to ONNX format" ./pruned_transducer_stateless2/export-onnx.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 @@ -38,14 +38,14 @@ log "Export to torchscript model" ./pruned_transducer_stateless2/export.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 ./pruned_transducer_stateless2/export.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_char \ + --tokens 
$repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --jit-trace 1 diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py index 2ce5cfe69..c2dc0d5f3 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,10 +106,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -136,10 +136,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 723414167..2248c7a08 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -47,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -153,10 +153,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 params.datatang_prob = 0 logging.info(params) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py index 39d988cd0..4981fb71a 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py @@ -49,14 +49,14 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder2 import Decoder +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from zipformer 
import Zipformer from icefall.checkpoint import ( @@ -65,8 +65,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,12 +122,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -404,9 +401,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 01de5d772..bfd0ecb0c 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -23,7 +23,7 @@ Usage: ./transducer_stateless/export.py \ --exp-dir ./transducer_stateless/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,8 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -92,10 +92,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -192,10 +192,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index c1081c32b..4f2c71d18 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -46,6 +46,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,7 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -99,10 +100,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -190,10 +191,9 @@ def main(): logging.info(f"device: {device}") - lexicon = 
Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index 3e14ad69c..487748947 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -46,6 +46,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -55,8 +56,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -99,10 +99,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -190,10 +190,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index 8a5be94d0..c92c7ab83 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char + --tokens ./data/lang_char/tokens.txt \ --epoch 25 \ --avg 5 @@ -48,6 +48,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,10 +115,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -154,10 +154,10 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.unk_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index bf9856c60..246820833 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -48,6 +48,7 @@ import argparse import logging from pathlib import Path 
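# A minimal sketch (in comments only, assuming the usual icefall token-table
# layout) of the pattern this commit applies to each export script: the
# Lexicon/--lang-dir setup is replaced by a plain tokens.txt with one
# "<symbol> <integer-id>" pair per line and the blank symbol at id 0:
#
#     token_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")
#     blank_id = token_table["<blk>"]           # typically 0
#     vocab_size = num_tokens(token_table) + 1  # +1 accounts for the blank
#
# num_tokens() comes from icefall.utils and counts the real tokens while
# skipping disambiguation symbols such as "#0".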
+import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,13 +115,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -157,9 +154,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 8e5cc6075..5dc73c52b 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -20,7 +20,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens ./data/lang_char/tokens.txt \ --epoch 29 \ --avg 18 @@ -45,12 +45,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -85,10 +85,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -122,10 +122,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py index 23a88dd29..8bafaef44 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_char/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. 
./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_char/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,9 +86,8 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch -import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -98,8 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -156,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -199,10 +197,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index 2a52e2eec..1ce770128 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./lstm_transducer_stateless2/export-onnx-zh.py \ - --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \ + --tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \ --use-averaged-model 1 \ --epoch 11 \ --avg 1 \ @@ -55,6 +55,7 @@ import logging from pathlib import Path from typing import Dict, Optional, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -70,8 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -441,9 +441,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py index 1bb6277ad..44b2e95d6 100755 --- a/egs/swbd/ASR/conformer_ctc/export.py +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -23,12 +23,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +63,10 @@ def 
get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -105,9 +104,9 @@ def main(): logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 device = torch.device("cpu") if torch.cuda.is_available(): @@ -119,7 +118,7 @@ def main(): num_features=params.feature_dim, nhead=params.nhead, d_model=params.attention_dim, - num_classes=num_classes, + num_classes=params.vocab_size, subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=False, diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py index 009bea230..b5bf911c2 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/export.py +++ b/egs/tedlium3/ASR/conformer_ctc2/export.py @@ -45,6 +45,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from scaling_converter import convert_scaled_to_non_scaled @@ -56,8 +57,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser() -> argparse.ArgumentParser: @@ -118,10 +118,10 @@ def get_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="The lang dir", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -166,9 +166,9 @@ def main(): params = get_params() params.update(vars(args)) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 device = torch.device("cpu") if torch.cuda.is_available(): @@ -182,7 +182,7 @@ def main(): model = Conformer( num_features=params.feature_dim, - num_classes=num_classes, + num_classes=params.vocab_size, subsampling_factor=params.subsampling_factor, d_model=params.dim_model, nhead=params.nhead, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index 140b1d37f..8aea79fe3 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -28,7 +28,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless2/export-onnx.py \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --exp-dir $repo/exp @@ -48,6 +48,7 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -57,14 +58,8 @@ from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -110,10 +105,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -397,9 +392,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 921766ad4..30068d01a 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -58,13 +58,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -from icefall.lexicon import Lexicon import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -74,7 +74,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.lexicon import Lexicon +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -131,10 +132,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -490,9 +491,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py index 037c7adf1..1c9eb8648 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -28,7 +28,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless5/export-onnx.py \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,6 +55,7 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -70,8 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -417,9 +417,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) From b07d5472c58eb11bc86cdcf787ecfdb634a55c73 Mon Sep 17 00:00:00 2001 From: Henry Li Xinyuan <57513260+HSTEHSTEHSTE@users.noreply.github.com> Date: Wed, 31 Jan 2024 09:53:36 -0500 Subject: [PATCH 089/216] Implement recipe for Fluent Speech Commands dataset (#1469) --------- Signed-off-by: Xinyuan Li --- egs/fluent_speech_commands/SLU/README.md | 9 + .../SLU/local/compile_hlg.py | 136 ++++ .../SLU/local/compute_fbank_slu.py | 97 +++ .../SLU/local/generate_lexicon.py | 59 ++ .../SLU/local/prepare_lang.py | 371 +++++++++++ egs/fluent_speech_commands/SLU/prepare.sh | 103 +++ egs/fluent_speech_commands/SLU/shared | 1 + .../SLU/transducer/__init__.py | 0 .../SLU/transducer/beam_search.py | 71 ++ .../SLU/transducer/conformer.py | 1 + .../SLU/transducer/decode.py | 346 ++++++++++ .../SLU/transducer/decoder.py | 1 + .../SLU/transducer/encoder_interface.py | 1 + .../SLU/transducer/joiner.py | 1 + .../SLU/transducer/model.py | 1 + .../SLU/transducer/slu_datamodule.py | 289 ++++++++ .../SLU/transducer/subsampling.py | 1 + .../SLU/transducer/test_conformer.py | 1 + .../SLU/transducer/test_decoder.py | 1 + .../SLU/transducer/test_joiner.py | 1 + .../SLU/transducer/test_transducer.py | 1 + .../SLU/transducer/train.py | 625 ++++++++++++++++++ .../SLU/transducer/transformer.py | 1 + icefall/shared/make_kn_lm.py | 9 +- 24 files changed, 2124 insertions(+), 3 deletions(-) create mode 100755 egs/fluent_speech_commands/SLU/README.md create mode 100755 egs/fluent_speech_commands/SLU/local/compile_hlg.py create mode 100755 egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py create mode 100755 egs/fluent_speech_commands/SLU/local/generate_lexicon.py create mode 100755 egs/fluent_speech_commands/SLU/local/prepare_lang.py create mode 100755 egs/fluent_speech_commands/SLU/prepare.sh create mode 120000 egs/fluent_speech_commands/SLU/shared create mode 100755 egs/fluent_speech_commands/SLU/transducer/__init__.py create mode 100755 egs/fluent_speech_commands/SLU/transducer/beam_search.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/conformer.py create mode 100755 egs/fluent_speech_commands/SLU/transducer/decode.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/decoder.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/encoder_interface.py create mode 120000 
egs/fluent_speech_commands/SLU/transducer/joiner.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/model.py create mode 100755 egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/subsampling.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_conformer.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_decoder.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_joiner.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_transducer.py create mode 100755 egs/fluent_speech_commands/SLU/transducer/train.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/transformer.py diff --git a/egs/fluent_speech_commands/SLU/README.md b/egs/fluent_speech_commands/SLU/README.md new file mode 100755 index 000000000..a203a9bfb --- /dev/null +++ b/egs/fluent_speech_commands/SLU/README.md @@ -0,0 +1,9 @@ +## Fluent Speech Commands recipe + +This is a recipe for the Fluent Speech Commands dataset, a speech dataset which transcribes short utterances (such as "turn the lights on in the kitchen") into action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances, whereas the test set contains 3793 utterances. + +Dataset Paper link: + +cd icefall/egs/fluent_speech_commands/ +Training: python transducer/train.py +Decoding: python transducer/decode.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/local/compile_hlg.py b/egs/fluent_speech_commands/SLU/local/compile_hlg.py new file mode 100755 index 000000000..a7df8f966 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/local/compile_hlg.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 + +""" +This script takes as input lang_dir and generates HLG from + + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G.fst.txt + +The generated HLG is saved in $lang_dir/HLG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_HLG(lang_dir: str) -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + + Return: + An FSA representing HLG. + """ + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + logging.info(f"Building ctc_topo. 
max_token_id: {max_token_id}") + H = k2.ctc_topo(max_token_id) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + logging.info("Loading G.fst.txt") + with open(lang_dir / "G.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + # CAUTION: The name of the inner_labels is fixed + # to `tokens`. If you want to change it, please + # also change other places in icefall that are using + # it. + HLG = k2.compose(H, LG, inner_labels="tokens") + + logging.info("Connecting LG") + HLG = k2.connect(HLG) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + logging.info(f"HLG.shape: {HLG.shape}") + + return HLG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + HLG = compile_HLG(lang_dir) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py new file mode 100755 index 000000000..a51b7b47b --- /dev/null +++ b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 + +""" +This file computes fbank features of the Fluent Speech Commands dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or it wastes a +# lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_slu(manifest_dir, fbanks_dir): + src_dir = Path(manifest_dir) + output_dir = Path(fbanks_dir) + + # This dataset is rather small, so we use only one job + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 23 + + dataset_parts = ( + "train", + "valid", + "test", + ) + prefix = "slu" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}" + if cuts_file.is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 1, # use one job + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(cuts_file) + + +parser = argparse.ArgumentParser() +parser.add_argument("manifest_dir") +parser.add_argument("fbanks_dir") + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + args = parser.parse_args() + + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_slu(args.manifest_dir, args.fbanks_dir) diff --git a/egs/fluent_speech_commands/SLU/local/generate_lexicon.py b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py new file mode 100755 index 000000000..6263e062f --- /dev/null +++ b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py @@ -0,0 +1,59 @@ +import argparse + +import pandas +from tqdm import tqdm + + +def generate_lexicon(corpus_dir, lm_dir): + data = pandas.read_csv( + str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0 + ) + vocab_transcript = set() + vocab_frames = set() + transcripts = data["transcription"].tolist() + frames = list( + i + for i in zip( + data["action"].tolist(), data["object"].tolist(), data["location"].tolist() + ) + ) + + for transcript in tqdm(transcripts): + for word in transcript.split(): + vocab_transcript.add(word) + + for frame in tqdm(frames): + for word in frame: + vocab_frames.add("_".join(word.split())) + + with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file: + lexicon_transcript_file.write(" 1" + "\n") + lexicon_transcript_file.write(" 2" + "\n") + lexicon_transcript_file.write(" 0" + "\n") + id = 3 + for vocab in vocab_transcript: + lexicon_transcript_file.write(vocab + " " + str(id) + "\n") + id += 1 + + with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file: + lexicon_frames_file.write(" 1" + "\n") + lexicon_frames_file.write(" 2" + "\n") + lexicon_frames_file.write(" 0" + "\n") + id = 3 + for vocab in vocab_frames: + lexicon_frames_file.write(vocab + " " + str(id) + "\n") + id += 1 + + +parser = 
argparse.ArgumentParser() +parser.add_argument("corpus_dir") +parser.add_argument("lm_dir") + + +def main(): + args = parser.parse_args() + + generate_lexicon(args.corpus_dir, args.lm_dir) + + +main() diff --git a/egs/fluent_speech_commands/SLU/local/prepare_lang.py b/egs/fluent_speech_commands/SLU/local/prepare_lang.py new file mode 100755 index 000000000..2a71dcf81 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/local/prepare_lang.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. +""" +import argparse +import math +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch + +from icefall.lexicon import read_lexicon, write_lexicon + +Lexicon = List[Tuple[str, List[str]]] + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. + Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. + count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). 
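    # A toy illustration (not from this recipe's data) of the effect of
    # steps (1)-(3) in this function.  Given the lexicon
    #
    #     light   l i g h t
    #     lite    l i g h t
    #     lights  l i g h t s
    #
    # the sequence "l i g h t" is shared by two words and is also a prefix
    # of "l i g h t s", so the returned lexicon becomes
    #
    #     light   l i g h t #1
    #     lite    l i g h t #2
    #     lights  l i g h t s
    #
    # and max_disambig is 2.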
+ issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map(symbols: List[str]) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. + Returns: + A dict containing the mapping between symbols and IDs. + """ + return {sym: i for i, sym in enumerate(symbols)} + + +def add_self_loops( + arcs: List[List[Any]], disambig_token: int, disambig_word: int +) -> List[List[Any]]: + """Adds self-loops to states of an FST to propagate disambiguation symbols + through it. They are added on each state with non-epsilon output symbols + on at least one arc out of the state. + + See also fstaddselfloops.pl from Kaldi. One difference is that + Kaldi uses OpenFst style FSTs and it has multiple final states. + This function uses k2 style FSTs and it does not need to add self-loops + to the final state. + + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + + Args: + arcs: + A list-of-list. The sublist contains + `[src_state, dest_state, label, aux_label, score]` + disambig_token: + It is the token ID of the symbol `#0`. + disambig_word: + It is the word ID of the symbol `#0`. + + Return: + Return new `arcs` containing self-loops. + """ + states_needs_self_loops = set() + for arc in arcs: + src, dst, ilabel, olabel, score = arc + if olabel != 0: + states_needs_self_loops.add(src) + + ans = [] + for s in states_needs_self_loops: + ans.append([s, s, disambig_token, disambig_word, 0]) + + return arcs + ans + + +def lexicon_to_fst( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + sil_token: str = "!SIL", + sil_prob: float = 0.5, + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format) with optional silence at + the beginning and end of each word. + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + sil_token: + The silence token. + sil_prob: + The probability for adding a silence at the beginning and end + of the word. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. 
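    A minimal call sketch with toy tables (not the recipe's real lexicon),
    mirroring how main() below passes the word table for both token2id and
    word2id:

        lexicon = [("YES", ["YES"]), ("NO", ["NO"])]
        word2id = {"!SIL": 1, "YES": 2, "NO": 3}
        L = lexicon_to_fst(lexicon, token2id=word2id, word2id=word2id,
                           sil_token="!SIL", sil_prob=0.5)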
+ """ + assert sil_prob > 0.0 and sil_prob < 1.0 + # CAUTION: we use score, i.e, negative cost. + sil_score = math.log(sil_prob) + no_sil_score = math.log(1.0 - sil_prob) + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + arcs = [] + + # assert token2id[""] == 0 + # assert word2id[""] == 0 + + eps = 0 + sil_token = word2id[sil_token] + + arcs.append([start_state, loop_state, eps, eps, no_sil_score]) + arcs.append([start_state, sil_state, eps, eps, sil_score]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) + + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + tokens = [word2id[i] for i in tokens] + + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last token of this word + # It has two out-going arcs, one to the loop state, + # the other one to the sil_state. + i = len(tokens) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) + + if need_self_loops: + disambig_token = word2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +parser = argparse.ArgumentParser() +parser.add_argument("lm_dir") + + +def main(): + args = parser.parse_args() + + out_dir = Path(args.lm_dir) + lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"] + names = ["frames", "transcript"] + sil_token = "!SIL" + sil_prob = 0.5 + + for name, lexicon_filename in zip(names, lexicon_filenames): + lexicon = read_lexicon(lexicon_filename) + tokens = get_words(lexicon) + words = get_words(lexicon) + new_lexicon = [] + for lexicon_item in lexicon: + new_lexicon.append((lexicon_item[0], [lexicon_item[0]])) + lexicon = new_lexicon + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in tokens + tokens.append(f"#{i}") + + tokens = [""] + tokens + words = ["eps"] + words + ["#0", "!SIL"] + + token2id = generate_id_map(tokens) + word2id = generate_id_map(words) + + write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id) + write_mapping(out_dir / ("words_" + name + ".txt"), word2id) + write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig) + + L = lexicon_to_fst( + lexicon, + token2id=word2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + ) + + L_disambig = lexicon_to_fst( + lexicon_disambig, + token2id=word2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + need_self_loops=True, + ) + torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt")) + torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt")) + + if False: + # Just for debugging, will remove it + 
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") + L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") + L_disambig.labels_sym = L.labels_sym + L_disambig.aux_labels_sym = L.aux_labels_sym + L.draw(out_dir / "L.png", title="L") + L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") + + +main() diff --git a/egs/fluent_speech_commands/SLU/prepare.sh b/egs/fluent_speech_commands/SLU/prepare.sh new file mode 100755 index 000000000..3ff339d91 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/prepare.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=1 +stop_stage=5 + +data_dir=path/to/fluent/speech/commands +target_root_dir=data/ + +lang_dir=${target_root_dir}/lang_phone +lm_dir=${target_root_dir}/lm +manifest_dir=${target_root_dir}/manifests +fbanks_dir=${target_root_dir}/fbanks + +. shared/parse_options.sh || exit 1 + +mkdir -p $lang_dir +mkdir -p $lm_dir + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "data_dir: $data_dir" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare slu manifest" + mkdir -p $manifest_dir + lhotse prepare slu $data_dir $manifest_dir +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute fbank for SLU" + mkdir -p $fbanks_dir + python ./local/compute_fbank_slu.py $manifest_dir $fbanks_dir +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare lang" + # NOTE: " SIL" is added for implementation convenience + # as the graph compiler code requires that there is a OOV word + # in the lexicon. + python ./local/generate_lexicon.py $data_dir $lm_dir +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Train LM" + # We use a unigram G + ./shared/make_kn_lm.py \ + -ngram-order 1 \ + -text $lm_dir/words_transcript.txt \ + -lm $lm_dir/G_transcript.arpa + + ./shared/make_kn_lm.py \ + -ngram-order 1 \ + -text $lm_dir/words_frames.txt \ + -lm $lm_dir/G_frames.arpa + + python ./local/prepare_lang.py $lm_dir + + if [ ! -f $lm_dir/G_transcript.fst.txt ]; then + python -m kaldilm \ + --read-symbol-table="$lm_dir/words_transcript.txt" \ + $lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt + fi + + if [ ! -f $lm_dir/G_frames.fst.txt ]; then + python -m kaldilm \ + --read-symbol-table="$lm_dir/words_frames.txt" \ + $lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt + fi + + mkdir -p $lm_dir/frames + mkdir -p $lm_dir/transcript + + chmod -R +777 . 
+ + for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt; + do + j=${i//"_frames"/} + mv "$lm_dir/$i" $lm_dir/frames/$j + done + + for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt; + do + j=${i//"_transcript"/} + mv "$lm_dir/$i" $lm_dir/transcript/$j + done +fi + + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compile HLG" + ./local/compile_hlg.py --lang-dir $lm_dir/frames + ./local/compile_hlg.py --lang-dir $lm_dir/transcript + +fi diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared new file mode 120000 index 000000000..32a374a7f --- /dev/null +++ b/egs/fluent_speech_commands/SLU/shared @@ -0,0 +1 @@ +../../icefall/shared/ \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/__init__.py b/egs/fluent_speech_commands/SLU/transducer/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/egs/fluent_speech_commands/SLU/transducer/beam_search.py b/egs/fluent_speech_commands/SLU/transducer/beam_search.py new file mode 100755 index 000000000..a16aa0123 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/beam_search.py @@ -0,0 +1,71 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +from transducer.model import Transducer + + +def greedy_search( + model: Transducer, encoder_out: torch.Tensor, id2word: dict +) -> List[str]: + """ + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + Returns: + Return the decoded result. 
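    Note:
      id2word maps network output IDs back to strings (see get_id2word()
      in decode.py).  With a toy table such as
      {1: "activate", 2: "lights", 3: "kitchen"}, the returned value would
      look like ["activate", "lights", "kitchen"].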
+ """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + device = model.device + + sos = torch.tensor([blank_id], device=device).reshape(1, 1) + decoder_out, (h, c) = model.decoder(sos) + T = encoder_out.size(1) + t = 0 + hyp = [] + max_u = 1000 # terminate after this number of steps + u = 0 + + while t < T and u < max_u: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :] + # fmt: on + logits = model.joiner(current_encoder_out, decoder_out) + + log_prob = logits.log_softmax(dim=-1) + # log_prob is (N, 1, 1) + # TODO: Use logits.argmax() + y = log_prob.argmax() + if y != blank_id: + hyp.append(y.item()) + y = y.reshape(1, 1) + decoder_out, (h, c) = model.decoder(y, (h, c)) + u += 1 + else: + t += 1 + # id2word = {1: "YES", 2: "NO"} + + hyp = [id2word[i] for i in hyp] + + return hyp diff --git a/egs/fluent_speech_commands/SLU/transducer/conformer.py b/egs/fluent_speech_commands/SLU/transducer/conformer.py new file mode 120000 index 000000000..8be0dc864 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/decode.py b/egs/fluent_speech_commands/SLU/transducer/decode.py new file mode 100755 index 000000000..ba2b9aaea --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/decode.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import torch +import torch.nn as nn +from transducer.beam_search import greedy_search +from transducer.conformer import Conformer +from transducer.decoder import Decoder +from transducer.joiner import Joiner +from transducer.model import Transducer +from transducer.slu_datamodule import SluDataModule + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_id2word(params): + id2word = {} + + # 0 is blank + id = 1 + try: + with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: + for line in lexicon_file: + if len(line.strip()) > 0: + id2word[id] = line.split()[0] + id += 1 + except: + pass + + return id2word + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=6, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + parser.add_argument( + "--exp-dir", + type=str, + default="transducer/exp", + help="Directory from which to load the checkpoints", + ) + parser.add_argument("--lang-dir", type=str, default="data/lm/frames") + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + "lang_dir": Path("data/lm/frames"), + # encoder/decoder params + "vocab_size": 3, # blank, yes, no + "blank_id": 0, + "embedding_dim": 32, + "hidden_dim": 16, + "num_decoder_layers": 4, + } + ) + + vocab_size = 1 + with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file: + for line in lexicon_file: + if ( + len(line.strip()) > 0 + ): # and '' not in line and '' not in line and '' not in line: + vocab_size += 1 + params.vocab_size = vocab_size + + return params + + +def decode_one_batch( + params: AttributeDict, model: nn.Module, batch: dict, id2word: dict +) -> List[List[int]]: + """Decode one batch and return the result in a list-of-list. + Each sub list contains the word IDs for an utterance in the batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding. + - params.method is "nbest", it uses nbest decoding. + + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py) + Returns: + Return the decoding result. `len(ans)` == batch size. + """ + device = model.device + feature = batch["inputs"] + feature = feature.to(device) + # at entry, feature is (N, T, C) + feature_lens = batch["supervisions"]["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + hyp = greedy_search(model=model, encoder_out=encoder_out_i, id2word=id2word) + hyps.append(hyp) + + # hyps = [[word_table[i] for i in ids] for ids in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> List[Tuple[List[int], List[int]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a tuple contains two elements (ref_text, hyp_text): + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
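    # The reference texts below are rebuilt from each cut's "frames" custom
    # field, e.g. ["activate", "lights", "kitchen"] -> "activate lights
    # kitchen"; the frame value "change language" is re-joined as
    # "change_language" so that it is scored as a single token, matching the
    # lexicon produced by local/generate_lexicon.py.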
+ + id2word = get_id2word(params) + + results = [] + for batch_idx, batch in enumerate(dl): + texts = [ + " ".join(a.supervisions[0].custom["frames"]) + for a in batch["supervisions"]["cut"] + ] + texts = [ + " " + a.replace("change language", "change_language") + " " + for a in texts + ] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps = decode_one_batch( + params=params, model=model, batch=batch, id2word=id2word + ) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + exp_dir: Path, + test_set_name: str, + results: List[Tuple[List[int], List[int]]], +) -> None: + """Save results to `exp_dir`. + Args: + exp_dir: + The output directory. This function create the following files inside + this directory: + + - recogs-{test_set_name}.text + + It contains the reference and hypothesis results, like below:: + + ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] + hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] + ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] + hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] + + - errs-{test_set_name}.txt + + It contains the detailed WER. + test_set_name: + The name of the test set, which will be part of the result filename. + results: + A list of tuples, each of which contains (ref_words, hyp_words). + Returns: + Return None. + """ + recog_path = exp_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
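    # For this recipe the "WER" reported here is computed over frame tokens
    # (action / object / location values) rather than over spoken words,
    # because decode_dataset() builds the references from the cuts'
    # "frames" field.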
+ errs_filename = exp_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + write_error_stats(f, f"{test_set_name}", results) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + +def get_transducer_model(params: AttributeDict): + # encoder = Tdnn( + # num_features=params.feature_dim, + # output_dim=params.hidden_dim, + # ) + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.hidden_dim, + ) + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.hidden_dim, + embedding_dropout=0.4, + rnn_dropout=0.4, + ) + joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) + return transducer + + +@torch.no_grad() +def main(): + parser = get_parser() + SluDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params["env_info"] = get_env_info() + + setup_logger(f"{params.exp_dir}/log/log-decode") + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = get_transducer_model(params) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to(device) + model.eval() + model.device = device + + # we need cut ids to display recognition results. 
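    # decode_dataset() reads batch["supervisions"]["cut"] both for the cut
    # ids and for the "frames" custom field, so return_cuts is forced to
    # True here regardless of the command-line value.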
+ args.return_cuts = True + slu = SluDataModule(args) + test_dl = slu.test_dataloaders() + results = decode_dataset( + dl=test_dl, + params=params, + model=model, + ) + + test_set_name = str(args.feature_dir).split("/")[-2] + save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/fluent_speech_commands/SLU/transducer/decoder.py b/egs/fluent_speech_commands/SLU/transducer/decoder.py new file mode 120000 index 000000000..e99310f91 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/decoder.py @@ -0,0 +1 @@ +../../../yesno/ASR/transducer/decoder.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/joiner.py b/egs/fluent_speech_commands/SLU/transducer/joiner.py new file mode 120000 index 000000000..75fa64868 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/joiner.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/model.py b/egs/fluent_speech_commands/SLU/transducer/model.py new file mode 120000 index 000000000..10f6ddad1 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/model.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py new file mode 100755 index 000000000..fa715abdd --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py @@ -0,0 +1,289 @@ +# Copyright 2021 Piotr Żelasko +# 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import List + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool + + +class SluDataModule(DataModule): + """ + DataModule for k2 ASR experiments. + It assumes there is always one train dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). 
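    A typical decode-time usage sketch (mirroring transducer/decode.py in
    this recipe):

        parser = argparse.ArgumentParser()
        SluDataModule.add_arguments(parser)
        args = parser.parse_args()
        test_dl = SluDataModule(args).test_dataloaders()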
+ + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + """ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--feature-dir", + type=Path, + default=Path("data/fbanks"), + help="Path to directory with train/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=30.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=False, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=10, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders(self) -> DataLoader: + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + + logging.info("About to create train dataset") + transforms = [] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. 
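For reference, the pipeline that the options above configure can be reproduced with plain Lhotse calls; the sketch below mirrors the defaults of the arguments just described and uses the recipe's default `--feature-dir` manifest path:

from lhotse import CutSet
from lhotse.dataset import (
    CutConcatenate,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
)
from torch.utils.data import DataLoader

cuts = CutSet.from_file("data/fbanks/slu_cuts_train.jsonl.gz")

dataset = K2SpeechRecognitionDataset(
    # concatenation goes first, so later augmentations also cover the joined gaps
    cut_transforms=[CutConcatenate(duration_factor=1.0, gap=1.0)],
    return_cuts=True,
)
sampler = DynamicBucketingSampler(
    cuts, max_duration=30.0, num_buckets=10, shuffle=True, drop_last=True
)
dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)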
+ transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + FbankConfig(sampling_rate=8000, num_mel_bins=23) + ), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + + return train_dl + + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get valid cuts") + cuts_valid = self.valid_cuts() + + logging.debug("About to create valid dataset") + valid = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create valid dataloader") + valid_dl = DataLoader( + valid, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + return valid_dl + + def test_dataloaders(self) -> DataLoader: + logging.info("About to get test cuts") + cuts_test = self.test_cuts() + + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts_test, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest_lazy( + self.args.feature_dir / "slu_cuts_train.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> List[CutSet]: + logging.info("About to get valid cuts") + cuts_valid = load_manifest_lazy( + self.args.feature_dir / "slu_cuts_valid.jsonl.gz" + ) + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + 
logging.info("About to get test cuts") + cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz") + return cuts_test diff --git a/egs/fluent_speech_commands/SLU/transducer/subsampling.py b/egs/fluent_speech_commands/SLU/transducer/subsampling.py new file mode 120000 index 000000000..fd7ca8b30 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/subsampling.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_conformer.py b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py new file mode 120000 index 000000000..3060dd70c --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/test_conformer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_decoder.py b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py new file mode 120000 index 000000000..d1bc718ce --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py @@ -0,0 +1 @@ +../../../yesno/ASR/transducer/test_decoder.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_joiner.py b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py new file mode 120000 index 000000000..248222a8a --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/test_joiner.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_transducer.py b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py new file mode 120000 index 000000000..df104bad7 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/test_transducer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/train.py b/egs/fluent_speech_commands/SLU/transducer/train.py new file mode 100755 index 000000000..a59c0b754 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/train.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import List, Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from lhotse.utils import fix_random_seed +from slu_datamodule import SluDataModule +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from transducer.conformer import Conformer + +# from torch.utils.tensorboard import SummaryWriter +from transducer.decoder import Decoder +from transducer.joiner import Joiner +from transducer.model import Transducer + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_word2id(params): + word2id = {} + + # 0 is blank + id = 1 + with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: + for line in lexicon_file: + if len(line.strip()) > 0: + word2id[line.split()[0]] = id + id += 1 + + return word2id + + +def get_labels(texts: List[str], word2id) -> k2.RaggedTensor: + """ + Args: + texts: + A list of transcripts. + Returns: + Return a ragged tensor containing the corresponding word ID. + """ + # blank is 0 + word_ids = [] + for t in texts: + words = t.split() + ids = [word2id[w] for w in words] + word_ids.append(ids) + + return k2.RaggedTensor(word_ids) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=7, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + tdnn/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer/exp", + help="Directory to save results", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument("--lang-dir", type=str, default="data/lm/frames") + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - weight_decay: The weight_decay for the optimizer. + + - subsampling_factor: The subsampling factor for the model. + + - start_epoch: If it is not zero, load checkpoint `start_epoch-1` + and continue training from that checkpoint. 
+ + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + + """ + params = AttributeDict( + { + "lr": 1e-4, + "feature_dim": 23, + "weight_decay": 1e-6, + "start_epoch": 0, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 20, + "valid_interval": 3000, + "exp_dir": Path("transducer/exp"), + "lang_dir": Path("data/lm/frames"), + # encoder/decoder params + "vocab_size": 3, # blank, yes, no + "blank_id": 0, + "embedding_dim": 32, + "hidden_dim": 16, + "num_decoder_layers": 4, + } + ) + + vocab_size = 1 + with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: + for line in lexicon_file: + if ( + len(line.strip()) > 0 + ): # and '' not in line and '' not in line and '' not in line: + vocab_size += 1 + params.vocab_size = vocab_size + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Tdnn in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + feature_lens = batch["supervisions"]["num_frames"].to(device) + + texts = [ + " ".join(a.supervisions[0].custom["frames"]) + for a in batch["supervisions"]["cut"] + ] + texts = [ + " " + a.replace("change language", "change_language") + " " + for a in texts + ] + labels = get_labels(texts, word2ids).to(device) + + with torch.set_grad_enabled(is_training): + loss = model(x=feature, x_lens=feature_lens, y=labels) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = feature.size(0) + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + word2ids, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=False, + word2ids=word2ids, + ) + assert loss.requires_grad is False + + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + word2ids, + tb_writer: None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. 
If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, model=model, batch=batch, is_training=True, word2ids=word2ids + ) + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + word2ids=word2ids, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, + "train/valid_", + params.batch_idx_train, + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def get_transducer_model(params: AttributeDict): + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.hidden_dim, + ) + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.hidden_dim, + embedding_dropout=0.4, + rnn_dropout=0.4, + ) + joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) + + return transducer + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + + params.update(vars(args)) + params["env_info"] = get_env_info() + + word2ids = get_word2id(params) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + # if args.tensorboard and rank == 0: + # tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + # else: + # tb_writer = None + tb_writer = None + + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + else: + device = torch.device("cpu") + logging.info(f"device: {device}") + + model = get_transducer_model(params) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + model.device = device + + optimizer = optim.Adam( + model.parameters(), + lr=params.lr, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + slu = SluDataModule(args) + train_dl = slu.train_dataloaders() + + # There are only 60 waves: 30 files are used for training + # and the remaining 30 files are used for testing. + # We use test data as validation. + valid_dl = slu.test_dataloaders() + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + word2ids=word2ids, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + scheduler=None, + rank=rank, + ) + + logging.info("Done!") + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + SluDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/fluent_speech_commands/SLU/transducer/transformer.py b/egs/fluent_speech_commands/SLU/transducer/transformer.py new file mode 120000 index 000000000..214afed39 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/transformer.py \ No newline at end of file diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py index 231aca7f1..42ed44fdd 100755 --- a/icefall/shared/make_kn_lm.py +++ b/icefall/shared/make_kn_lm.py @@ -33,7 +33,7 @@ parser.add_argument( "-ngram-order", type=int, default=4, - choices=[2, 3, 4, 5, 6, 7], + choices=[1, 2, 3, 4, 5, 6, 7], help="Order of n-gram", ) parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") @@ -105,7 +105,7 @@ class NgramCounts: # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. 
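As an illustration of the layout described in that comment, using a tuple for the history key (a standalone sketch, not the module's actual implementation):

from collections import defaultdict

ngram_order = 4
# counts[n] holds the statistics for histories of length n: a dict keyed by the
# history tuple, whose values map the predicted word to its accumulated count.
counts = [defaultdict(lambda: defaultdict(float)) for _ in range(ngram_order)]

counts[3][(5, 6, 7)][8] += 1.0  # word 8 observed after the length-3 history (5, 6, 7)
counts[0][()][8] += 1.0         # unigram counts live under the empty history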
def __init__(self, ngram_order, bos_symbol="", eos_symbol=""): - assert ngram_order >= 2 + assert ngram_order >= 1 self.ngram_order = ngram_order self.bos_symbol = bos_symbol @@ -169,7 +169,10 @@ class NgramCounts: with open(filename, encoding=default_encoding) as fp: for line in fp: line = line.strip(strip_chars) - self.add_raw_counts_from_line(line) + if self.ngram_order == 1: + self.add_raw_counts_from_line(line.split()[0]) + else: + self.add_raw_counts_from_line(line) lines_processed += 1 if lines_processed == 0 or args.verbose > 0: print( From b9e6327adfe0def4989ad508d0df8f6347da2c0f Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Sat, 3 Feb 2024 07:25:27 +0900 Subject: [PATCH 090/216] Fixing torch.ctc err (#1485) * fixing torch.ctc err * Move targets & lengths to CPU --- egs/librispeech/ASR/zipformer/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index f2f86af47..73009d35c 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -164,9 +164,9 @@ class AsrModel(nn.Module): ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets, - input_lengths=encoder_out_lens, - target_lengths=target_lengths, + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), reduction="sum", ) return ctc_loss From a813186f6463b4d3f6d73460dd1a57b2e975b803 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 5 Feb 2024 12:47:52 +0800 Subject: [PATCH 091/216] minor fix for docstr and default param. (#1490) * Update train.py and README.md --- README.md | 3 +++ egs/aishell/ASR/whisper/train.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index cc817702b..770066166 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,9 @@ The [LibriSpeech][librispeech] recipe supports the most comprehensive set of mod - LSTM-based Predictor - [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/) +#### Whisper + - [OpenAi Whisper](https://arxiv.org/abs/2212.04356) (We support fine-tuning on AiShell-1.) + If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details. We would like to highlight the performance of some of the recipes here. diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index d16793eb2..073b23713 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -19,7 +19,7 @@ Usage: #fine-tuning with deepspeed zero stage 1 -torchrun --nproc-per-node 8 ./whisper/train.py \ +torchrun --nproc_per_node 8 ./whisper/train.py \ --max-duration 200 \ --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ @@ -28,7 +28,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \ --deepspeed_config ./whisper/ds_config_zero1.json # fine-tuning with ddp -torchrun --nproc-per-node 8 ./whisper/train.py \ +torchrun --nproc_per_node 8 ./whisper/train.py \ --max-duration 200 \ --exp-dir whisper/exp_medium \ --manifest-dir data/fbank_whisper \ @@ -136,7 +136,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless7/exp", + default="whisper/exp", help="""The experiment dir. 
It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 777074046d6cb0f9b7ed7a98115299c04b8bcdf1 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 6 Feb 2024 18:25:43 +0800 Subject: [PATCH 092/216] Fine-tune recipe for Zipformer (#1484) 1. support finetune zipformer 2. update the usage; set a very large batch count --- .../local/compute_fbank_gigaspeech_splits.py | 18 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 17 +- .../ASR/zipformer/decode_gigaspeech.py | 1114 ++++++++++++ egs/librispeech/ASR/zipformer/finetune.py | 1521 +++++++++++++++++ 4 files changed, 2664 insertions(+), 6 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer/decode_gigaspeech.py create mode 100755 egs/librispeech/ASR/zipformer/finetune.py diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 1c71be0f9..176eb8a84 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -51,6 +51,14 @@ def get_parser(): "Determines batch size dynamically.", ) + parser.add_argument( + "--subset", + type=str, + default="XL", + choices=["XL", "L", "M", "S", "XS"], + help="Which subset to work with", + ) + parser.add_argument( "--num-splits", type=int, @@ -76,7 +84,7 @@ def get_parser(): def compute_fbank_gigaspeech_splits(args): num_splits = args.num_splits - output_dir = "data/fbank/XL_split" + output_dir = f"data/fbank/{args.subset}_split" output_dir = Path(output_dir) assert output_dir.exists(), f"{output_dir} does not exist!" @@ -96,15 +104,15 @@ def compute_fbank_gigaspeech_splits(args): logging.info(f"device: {device}") for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) + idx = f"{i}".zfill(num_digits) logging.info(f"Processing {idx}/{num_splits}") - cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz" + cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz" if cuts_path.is_file(): logging.info(f"{cuts_path} exists - skipping") continue - raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz" + raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz" logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) @@ -113,7 +121,7 @@ def compute_fbank_gigaspeech_splits(args): cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, - storage_path=f"{output_dir}/feats_XL_{idx}", + storage_path=f"{output_dir}/feats_{args.subset}_{idx}", num_workers=args.num_workers, batch_duration=args.batch_duration, overwrite=True, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index dd9e9ef1f..814390ad6 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -1,4 +1,4 @@ -# Copyright 2021 Piotr Żelasko +# Copyright 2021 Piotr Żelasko # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors @@ -475,3 +475,18 @@ class LibriSpeechAsrDataModule: return load_manifest_lazy( self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: 
+ logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py new file mode 100755 index 000000000..3cda337c0 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py @@ -0,0 +1,1114 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + 
modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["", ""] +gigaspeech_punctuations = [ + "", + "", + "", + "", +] +gigaspeech_garbage_utterance_tags = ["", "", "", ""] +non_scoring_words = ( + conversational_filler + + unk_tags + + gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. 
+ """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. 
+ """, + ) + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." 
+ assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} 
(excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
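+    # (return_cuts=True makes the dataloader attach the original lhotse Cut
+    # objects at batch["supervisions"]["cut"], which is where decode_dataset()
+    # reads the cut ids written to the recogs-*.txt files.)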
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts() + + dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts) + test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py new file mode 100755 index 000000000..843d103cc --- /dev/null +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -0,0 +1,1521 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Fine-tune without mux (i.e not mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 0 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +# Fine-tune without mux (i.e mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 1 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + 100000 + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. 
This is useful + if you want to maintain the performance on the original domain + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. 
" + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
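+    # For example (just an illustration, assuming the usual 10 ms frame shift,
+    # i.e. 100 fbank frames per second): a 10-second utterance has T = 1000,
+    # so encoder_embed emits (1000 - 7) // 2 = 496 frames (~50 Hz); the
+    # encoder's output_downsampling_factor=2 then halves that again (~25 Hz).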
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
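+    # Consequently the values recorded below are sums over the whole batch;
+    # they are turned into per-frame averages only when reported, e.g. as
+    # tot_loss["loss"] / tot_loss["frames"] in compute_validation_loss() and
+    # train_one_epoch().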
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
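+            # Standard torch.cuda.amp flow: scale the loss before backward so
+            # small fp16 gradients do not underflow; scaler.step() later
+            # unscales the gradients and calls optimizer.step() (skipping it
+            # if any gradient is inf/nan), and scaler.update() adapts the scale.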
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Computing validation loss on {valid_set}") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by 
`mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + else: + # resuming training + assert params.start_epoch > 1, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts() + if params.use_mux: + librispeech_cuts = librispeech.train_all_shuf_cuts() + train_cuts = CutSet.mux( + gigaspeech_cuts, # num cuts = 688182 + librispeech_cuts, # num cuts = 843723 + weights=[688182, 843723], + stop_early=True, + ) + else: + train_cuts = gigaspeech_cuts + logging.info(train_cuts) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 
second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + + valid_sets = ["librispeech", "gigaspeech"] + valid_dls = [ + librispeech.valid_dataloaders(valid_cuts), + librispeech.valid_dataloaders(gigaspeech_dev_cuts), + ] + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. 
See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 4ed88d948494d060a8c2827584d71bff616c3ca9 Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Wed, 7 Feb 2024 10:16:02 +0800 Subject: [PATCH 093/216] Update shared (#1487) There should be one more ../ --- egs/fluent_speech_commands/SLU/shared | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared index 32a374a7f..9115c7e17 120000 --- a/egs/fluent_speech_commands/SLU/shared +++ b/egs/fluent_speech_commands/SLU/shared @@ -1 +1 @@ -../../icefall/shared/ \ No newline at end of file +../../../icefall/shared/ From 711d6bc46297ba799174bae4f3ebb965224ad76f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Fri, 9 Feb 2024 10:44:19 +0800 Subject: [PATCH 094/216] Refactor prepare.sh in librispeech (#1493) * Refactor prepare.sh in librispeech, break it into three parts, prepare.sh (basic, minimal requirement for transducer), prepare_lm.sh (ngram & nnlm staff), prepare_mmi.sh (for MMI training). 
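
* In practice: a plain BPE transducer recipe only needs ./prepare.sh;
  run ./prepare_lm.sh on top of it when ngram / NNLM resources are needed
  for decoding, and ./prepare_mmi.sh only when training with MMI.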
--- egs/librispeech/ASR/RESULTS.md | 2 +- egs/librispeech/ASR/generate-lm.sh | 20 -- egs/librispeech/ASR/local/train_bpe_model.py | 15 + egs/librispeech/ASR/prepare.sh | 334 ++++--------------- egs/librispeech/ASR/prepare_lm.sh | 262 +++++++++++++++ egs/librispeech/ASR/prepare_mmi.sh | 45 +++ 6 files changed, 393 insertions(+), 285 deletions(-) delete mode 100755 egs/librispeech/ASR/generate-lm.sh create mode 100755 egs/librispeech/ASR/prepare_lm.sh create mode 100755 egs/librispeech/ASR/prepare_mmi.sh diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ebf5e89c4..ee5422aba 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1526,7 +1526,7 @@ done You may also decode using LODR + LM shallow fusion. This decoding method is proposed in . It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be -generated by `generate-lm.sh`, or you may download it from . +generated by `prepare_lm.sh` at stage 4, or you may download it from . The decoding command is as follows: diff --git a/egs/librispeech/ASR/generate-lm.sh b/egs/librispeech/ASR/generate-lm.sh deleted file mode 100755 index dacd276d1..000000000 --- a/egs/librispeech/ASR/generate-lm.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env bash - -lang_dir=data/lang_bpe_500 - -for ngram in 2 3 4 5; do - if [ ! -f $lang_dir/${ngram}gram.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order ${ngram} \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/${ngram}gram.arpa - fi - - if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=${ngram} \ - $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt - fi -done diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 43142aee4..5979d5b98 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -28,6 +28,7 @@ import argparse import shutil from pathlib import Path +from typing import Dict import sentencepiece as spm @@ -57,6 +58,18 @@ def get_args(): return parser.parse_args() +def generate_tokens(lang_dir: Path): + """ + Generate the tokens.txt from a bpe model. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f: + for sym, i in token2id.items(): + f.write(f"{sym} {i}\n") + + def main(): args = get_args() vocab_size = args.vocab_size @@ -95,6 +108,8 @@ def main(): shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + generate_tokens(lang_dir) + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 4a5072cc0..40dc3260d 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -6,8 +6,21 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail nj=15 -stage=-1 -stop_stage=100 +# run step 0 to step 5 by default +stage=0 +stop_stage=5 + +# Note: This script just prepare the minimal requirements that needed by a +# transducer training with bpe units. +# +# If you want to use ngram or nnlm, please continue running prepare_lm.sh after +# you succeed running this script. 
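+#
+# For example (a sketch of the intended order):
+#
+#   ./prepare.sh      # default stages 0-5: minimal setup for BPE transducer training
+#   ./prepare_lm.sh   # ngram / NNLM files needed by some decoding methods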
+# +# This script also contains the steps to generate phone based units, but they +# will not run automatically, you can generate the phone based units by +# bash prepare.sh --stage -1 --stop-stage -1 +# bash prepare.sh --stage 6 --stop-stage 6 + # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -17,6 +30,18 @@ stop_stage=100 # You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. # You can download them from https://www.openslr.org/12 # +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +# +# lm directory is not necessary for transducer training with bpe units, but it +# is needed by phone based modeling, you can download it by running +# bash prepare.sh --stage -1 --stop-stage -1 +# then you can see the following files in the directory. # - $dl_dir/lm # This directory contains the following files downloaded from # http://www.openslr.org/resources/11 @@ -28,14 +53,7 @@ stop_stage=100 # - librispeech-vocab.txt # - librispeech-lexicon.txt # - librispeech-lm-norm.txt.gz -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech + dl_dir=$PWD/download . shared/parse_options.sh || exit 1 @@ -60,6 +78,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "Running prepare.sh" + log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then @@ -159,13 +179,49 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare phone based lang" + log "Stage 5: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + files=$( + find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare phone based lang" lang_dir=data/lang_phone mkdir -p $lang_dir - (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - $dl_dir/lm/librispeech-lexicon.txt | - sort | uniq > $lang_dir/lexicon.txt + if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then + log "No lexicon file in $dl_dir/lm, please run :" + log "prepare.sh --stage -1 --stop-stage -1" + exit -1 + fi + + if [ ! -f $lang_dir/lexicon.txt ]; then + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/librispeech-lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + fi if [ ! 
-f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang.py --lang-dir $lang_dir @@ -187,253 +243,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then $lang_dir/L_disambig.fst fi fi - - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang_phone/words.txt $lang_dir - - if [ ! -f $lang_dir/transcript_words.txt ]; then - log "Generate data for BPE training" - files=$( - find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" - find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt" - find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare bigram token-level P for MMI training" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! -f $lang_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_dir/lexicon.txt \ - --transcript $lang_dir/transcript_words.txt \ - --oov "" \ - > $lang_dir/transcript_tokens.txt - fi - - if [ ! -f $lang_dir/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/P.arpa - fi - - if [ ! -f $lang_dir/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - $lang_dir/P.arpa > $lang_dir/P.fst.txt - fi - done -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt - fi - - if [ ! -f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt - fi - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! 
-f $lang_dir/HL.fst ]; then - ./local/prepare_lang_fst.py \ - --lang-dir $lang_dir \ - --ngram-G ./data/lm/G_3_gram.fst.txt - fi - done -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir - done -fi - -# Compile LG for RNN-T fast_beam_search decoding -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Compile LG" - ./local/compile_lg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_lg.py --lang-dir $lang_dir - done -fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Generate LM training data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - lang_dir=data/lang_bpe_${vocab_size} - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ - --lm-archive $out_dir/lm_data.pt - done -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Generate LM validation data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/valid.txt ]; then - files=$( - find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt" - find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > $out_dir/valid.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/valid.txt \ - --lm-archive $out_dir/lm_data-valid.pt - done -fi - -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Generate LM test data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/test.txt ]; then - files=$( - find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt" - find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > $out_dir/test.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/test.txt \ - --lm-archive $out_dir/lm_data-test.pt - done -fi - -if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then - log "Stage 14: Sort LM training data" - # Sort LM training data by sentence length in descending order - # for ease of training. - # - # Sentence length equals to the number of BPE tokens - # in a sentence. 
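The sorting criterion described in the comments above can be stated in a few lines of Python. The sketch below is an editorial illustration and not part of the patch: the real ./local/sort_lm_training_data.py operates on the binarized lm_data.pt archives produced earlier, whereas this toy version sorts raw transcript lines, and the file paths are placeholders.

    # Illustrative sketch only: order sentences by their BPE token count,
    # longest first -- the "sentence length" the comments above refer to.
    # Assumes a trained BPE model and a plain-text transcript file exist.
    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # placeholder path

    with open("data/lang_bpe_500/transcript_words.txt") as f:  # placeholder path
        sentences = [line.strip() for line in f if line.strip()]

    # Sentence length = number of BPE tokens in the sentence.
    sentences.sort(key=lambda s: len(sp.encode(s, out_type=int)), reverse=True)

    for s in sentences[:3]:
        print(len(sp.encode(s, out_type=int)), s)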
- - for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data.pt \ - --out-lm-data $out_dir/sorted_lm_data.pt \ - --out-statistics $out_dir/statistics.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-valid.pt \ - --out-lm-data $out_dir/sorted_lm_data-valid.pt \ - --out-statistics $out_dir/statistics-valid.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-test.pt \ - --out-lm-data $out_dir/sorted_lm_data-test.pt \ - --out-statistics $out_dir/statistics-test.txt - done -fi diff --git a/egs/librispeech/ASR/prepare_lm.sh b/egs/librispeech/ASR/prepare_lm.sh new file mode 100755 index 000000000..a8eb5ca78 --- /dev/null +++ b/egs/librispeech/ASR/prepare_lm.sh @@ -0,0 +1,262 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +# This script generate Ngram LM / NNLM and related files that needed by decoding. + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz +# + +. prepare.sh --stage -1 --stop-stage 6 || exit 1 + +log "Running prepare_lm.sh" + +stage=0 +stop_stage=100 + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare BPE based lexicon." + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare word level G" + # We assume you have installed kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt + fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! 
-f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_dir \ + --ngram-G ./data/lm/G_3_gram.fst.txt + fi + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare token level ngram G" + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + > $lang_dir/transcript_tokens.txt + fi + + for ngram in 2 3 4 5; do + if [ ! -f $lang_dir/${ngram}gram.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/${ngram}gram.arpa + fi + + if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt + fi + done + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate NNLM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate NNLM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + files=$( + find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Generate NNLM test data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! 
-f $out_dir/test.txt ]; then + files=$( + find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $out_dir/test.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/test.txt \ + --lm-archive $out_dir/lm_data-test.pt + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Sort NNLM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/librispeech/ASR/prepare_mmi.sh b/egs/librispeech/ASR/prepare_mmi.sh new file mode 100755 index 000000000..d8a6e0caf --- /dev/null +++ b/egs/librispeech/ASR/prepare_mmi.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + + +. prepare.sh --stage -1 --stop-stage 6 || exit 1 + +log "Running prepare_mmi.sh" + +stage=0 +stop_stage=100 + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare bigram token-level P for MMI training" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + > $lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa > $lang_dir/P.fst.txt + fi + done +fi From d9ae8c02a0abdeddc5a4cf9fad72293eda134de3 Mon Sep 17 00:00:00 2001 From: safarisadegh <37085305+safarisadegh@users.noreply.github.com> Date: Fri, 9 Feb 2024 10:35:01 +0330 Subject: [PATCH 095/216] Update README.md (#1497) --- icefall/ctc/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/ctc/README.md b/icefall/ctc/README.md index 0096bc096..1e342f6a3 100644 --- a/icefall/ctc/README.md +++ b/icefall/ctc/README.md @@ -12,6 +12,6 @@ pip install kaldifst kaldi-decoder ``` to install the dependencies. 
-[kaldi-decoder]: https://github.com/i2-fsa/kaldi-decoder +[kaldi-decoder]: https://github.com/k2-fsa/kaldi-decoder [kaldifst]: https://github.com/k2-fsa/kaldifst [k2]: https://github.com/k2-fsa/k2 From 06b356a610ec0bcef8982f011ba2a46cd8ca29b5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 18 Feb 2024 12:05:38 +0800 Subject: [PATCH 096/216] Update cpu docker images to support torch 2.2.0 (#1499) --- .github/scripts/docker/Dockerfile | 1 + .../scripts/docker/generate_build_matrix.py | 21 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index f6a088af1..ee0099911 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -11,6 +11,7 @@ ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}" RUN apt-get update -y && \ apt-get install -qq -y \ + cmake \ ffmpeg \ git \ git-lfs \ diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index bdde97647..f0690f8bf 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -6,8 +6,8 @@ import json def version_gt(a, b): - a_major, a_minor = a.split(".")[:2] - b_major, b_minor = b.split(".")[:2] + a_major, a_minor = list(map(int, a.split(".")))[:2] + b_major, b_minor = list(map(int, b.split(".")))[:2] if a_major > b_major: return True @@ -18,8 +18,8 @@ def version_gt(a, b): def version_ge(a, b): - a_major, a_minor = a.split(".")[:2] - b_major, b_minor = b.split(".")[:2] + a_major, a_minor = list(map(int, a.split(".")))[:2] + b_major, b_minor = list(map(int, b.split(".")))[:2] if a_major > b_major: return True @@ -43,11 +43,12 @@ def get_torchaudio_version(torch_version): def get_matrix(): - k2_version = "1.24.4.dev20231220" - kaldifeat_version = "1.25.3.dev20231221" - version = "1.2" - python_version = ["3.8", "3.9", "3.10", "3.11"] + k2_version = "1.24.4.dev20240211" + kaldifeat_version = "1.25.4.dev20240210" + version = "1.3" + python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + torch_version += ["2.2.0"] matrix = [] for p in python_version: @@ -57,6 +58,10 @@ def get_matrix(): if version_gt(p, "3.10") and not version_gt(t, "2.0"): continue + # only torch>=2.2.0 supports python 3.12 + if version_gt(p, "3.11") and not version_gt(t, "2.1"): + continue + matrix.append( { "k2-version": k2_version, From 17688476e5cbdba92c682d3a75e3941b647573a7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 18 Feb 2024 14:56:04 +0800 Subject: [PATCH 097/216] Provider docker images for torch 2.2.0 (#1501) --- .github/workflows/build-docker-image.yml | 2 +- .github/workflows/run-docker-image.yml | 9 ++- docker/torch1.12.1-cuda11.3.dockerfile | 4 +- docker/torch1.13.0-cuda11.6.dockerfile | 4 +- docker/torch1.9.0-cuda10.2.dockerfile | 4 +- docker/torch2.0.0-cuda11.7.dockerfile | 4 +- docker/torch2.1.0-cuda11.8.dockerfile | 4 +- docker/torch2.1.0-cuda12.1.dockerfile | 4 +- docker/torch2.2.0-cuda11.8.dockerfile | 70 ++++++++++++++++++++++++ docker/torch2.2.0-cuda12.1.dockerfile | 70 ++++++++++++++++++++++++ docs/source/docker/intro.rst | 2 + 11 files changed, 163 insertions(+), 14 deletions(-) create mode 100644 docker/torch2.2.0-cuda11.8.dockerfile create mode 100644 docker/torch2.2.0-cuda12.1.dockerfile diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index 
e5d96dcdf..d5081f7d8 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index d048923b6..65ba2cd64 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -14,13 +14,20 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v2 with: fetch-depth: 0 + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + - name: Run the build process with Docker uses: addnab/docker-run-action@v3 with: diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index deb5715cc..cb885e59e 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.7 -ARG K2_VERSION="1.24.4.dev20230725+cuda11.3.torch1.12.1" -ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.3.torch1.12.1" +ARG K2_VERSION="1.24.4.dev20240211+cuda11.3.torch1.12.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.3.torch1.12.1" ARG TORCHAUDIO_VERSION="0.12.1+cu113" LABEL authors="Fangjun Kuang " diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index afc6c1b84..e238d87aa 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.9 -ARG K2_VERSION="1.24.4.dev20231021+cuda11.6.torch1.13.0" -ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.6.torch1.13.0" +ARG K2_VERSION="1.24.4.dev20240211+cuda11.6.torch1.13.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.6.torch1.13.0" ARG TORCHAUDIO_VERSION="0.13.0+cu116" LABEL authors="Fangjun Kuang " diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile index 9ff225b54..26d45cafc 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.7 -ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0" -ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda10.2.torch1.9.0" +ARG K2_VERSION="1.24.4.dev20240211+cuda10.2.torch1.9.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda10.2.torch1.9.0" ARG TORCHAUDIO_VERSION="0.9.0" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index db8076560..02906e53b 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -5,8 
+5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.10 -ARG K2_VERSION="1.24.4.dev20231021+cuda11.7.torch2.0.0" -ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.7.torch2.0.0" +ARG K2_VERSION="1.24.4.dev20240211+cuda11.7.torch2.0.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.7.torch2.0.0" ARG TORCHAUDIO_VERSION="2.0.0+cu117" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index b006b0d96..c87305922 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.10 -ARG K2_VERSION="1.24.4.dev20231021+cuda11.8.torch2.1.0" -ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.8.torch2.1.0" +ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.1.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.1.0" ARG TORCHAUDIO_VERSION="2.1.0+cu118" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index 1b078dc22..f4c297678 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.10 -ARG K2_VERSION="1.24.4.dev20231021+cuda12.1.torch2.1.0" -ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda12.1.torch2.1.0" +ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.1.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.1.0" ARG TORCHAUDIO_VERSION="2.1.0+cu121" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.2.0-cuda11.8.dockerfile b/docker/torch2.2.0-cuda11.8.dockerfile new file mode 100644 index 000000000..c59661c27 --- /dev/null +++ b/docker/torch2.2.0-cuda11.8.dockerfile @@ -0,0 +1,70 @@ +FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.2.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.2.0" +ARG TORCHAUDIO_VERSION="2.2.0+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.2.0-cuda12.1.dockerfile b/docker/torch2.2.0-cuda12.1.dockerfile new file mode 100644 index 
000000000..2c484efd5 --- /dev/null +++ b/docker/torch2.2.0-cuda12.1.dockerfile @@ -0,0 +1,70 @@ +FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.2.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.2.0" +ARG TORCHAUDIO_VERSION="2.2.0+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index cbd300d9b..149970eff 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -34,6 +34,8 @@ which will give you something like below: .. 
code-block:: bash + "torch2.2.0-cuda12.1" + "torch2.2.0-cuda11.8" "torch2.1.0-cuda12.1" "torch2.1.0-cuda11.8" "torch2.0.0-cuda11.7" From 7eb360d0d5f3eb03292d3ff4596a8d50c9765888 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 18 Feb 2024 20:32:40 +0800 Subject: [PATCH 098/216] Fix cpu docker images for torch 2.2.0 (#1502) --- .github/scripts/docker/generate_build_matrix.py | 4 ++-- .github/workflows/yesno.yml | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index f0690f8bf..425afac2b 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -43,8 +43,8 @@ def get_torchaudio_version(torch_version): def get_matrix(): - k2_version = "1.24.4.dev20240211" - kaldifeat_version = "1.25.4.dev20240210" + k2_version = "1.24.4.dev20240218" + kaldifeat_version = "1.25.4.dev20240218" version = "1.3" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] diff --git a/.github/workflows/yesno.yml b/.github/workflows/yesno.yml index 182300dfa..de822b33f 100644 --- a/.github/workflows/yesno.yml +++ b/.github/workflows/yesno.yml @@ -59,4 +59,7 @@ jobs: cd /icefall git config --global --add safe.directory /icefall + python3 -m torch.utils.collect_env + python3 -m k2.version + .github/scripts/yesno/ASR/run.sh From db4d66c0e39a06464f5c316c727ed76babeb10eb Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 19 Feb 2024 16:13:09 +0800 Subject: [PATCH 099/216] Fixed softlink for `ljspeech` recipe (#1503) --- egs/ljspeech/TTS/shared | 1 + egs/ljspeech/TTS/shared/parse_options.sh | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) create mode 120000 egs/ljspeech/TTS/shared delete mode 120000 egs/ljspeech/TTS/shared/parse_options.sh diff --git a/egs/ljspeech/TTS/shared b/egs/ljspeech/TTS/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/ljspeech/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh deleted file mode 120000 index e4665e7de..000000000 --- a/egs/ljspeech/TTS/shared/parse_options.sh +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file From b3e2044068001a24bc9293f5b7063377173631d3 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 19 Feb 2024 19:33:32 +0800 Subject: [PATCH 100/216] minor fix of vits/tokenizer.py (#1504) * minor fix of vits/tokenizer.py --- egs/ljspeech/TTS/vits/tokenizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 70f1240b4..b0afc6a04 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -74,7 +74,7 @@ class Tokenizer(object): if intersperse_blank: token_ids = intersperse(token_ids, self.blank_id) - token_ids_list.append(token_ids) + token_ids_list.append(token_ids) return token_ids_list @@ -103,6 +103,7 @@ class Tokenizer(object): if intersperse_blank: token_ids = intersperse(token_ids, self.blank_id) - token_ids_list.append(token_ids) + + token_ids_list.append(token_ids) return token_ids_list From e59fa38e86bd05241daa4217d66eaa0e36825547 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Tue, 20 Feb 2024 03:40:15 +0100 Subject: [PATCH 101/216] docs: minor fixes of LM rescoring 
texts (#1498) --- .../decoding-with-langugage-models/LODR.rst | 6 ++--- .../shallow-fusion.rst | 24 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst index b6b6e8cbb..d4b6f7065 100644 --- a/docs/source/decoding-with-langugage-models/LODR.rst +++ b/docs/source/decoding-with-langugage-models/LODR.rst @@ -30,7 +30,7 @@ of langugae model integration. First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here `_ to address the language information mismatch between the training corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain -are acoustically similar, DR derives the following formular for decoding with Bayes' theorem: +are acoustically similar, DR derives the following formula for decoding with Bayes' theorem: .. math:: @@ -41,7 +41,7 @@ are acoustically similar, DR derives the following formular for decoding with Ba where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively. -Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to +Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to shallow fusion is the subtraction of the source domain LM. Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is @@ -58,7 +58,7 @@ during decoding for transducer model: In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR, the only difference lies in the choice of source domain LM. According to the original `paper `_, -LODR achieves similar performance compared DR in both intra-domain and cross-domain settings. +LODR achieves similar performance compared to DR in both intra-domain and cross-domain settings. As a bi-gram is much faster to evaluate, LODR is usually much faster. Now, we will show you how to use LODR in ``icefall``. diff --git a/docs/source/decoding-with-langugage-models/shallow-fusion.rst b/docs/source/decoding-with-langugage-models/shallow-fusion.rst index 684fefeb4..8b2586730 100644 --- a/docs/source/decoding-with-langugage-models/shallow-fusion.rst +++ b/docs/source/decoding-with-langugage-models/shallow-fusion.rst @@ -9,9 +9,9 @@ to improve the word-error-rate of a transducer model. .. note:: - This tutorial is based on the recipe + This tutorial is based on the recipe `pruned_transducer_stateless7_streaming `_, - which is a streaming transducer model trained on `LibriSpeech`_. + which is a streaming transducer model trained on `LibriSpeech`_. However, you can easily apply shallow fusion to other recipes. If you encounter any problems, please open an issue here `icefall `_. @@ -69,11 +69,11 @@ Training a language model usually takes a long time, we can download a pre-train .. 
code-block:: bash $ # download the external LM - $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm $ # create a symbolic link so that the checkpoint can be loaded $ pushd icefall-librispeech-rnn-lm/exp $ git lfs pull --include "pretrained.pt" - $ ln -s pretrained.pt epoch-99.pt + $ ln -s pretrained.pt epoch-99.pt $ popd .. note:: @@ -85,7 +85,7 @@ Training a language model usually takes a long time, we can download a pre-train To use shallow fusion for decoding, we can execute the following command: .. code-block:: bash - + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ lm_dir=./icefall-librispeech-rnn-lm/exp $ lm_scale=0.29 @@ -133,16 +133,16 @@ The decoding result obtained with the above command are shown below. $ For test-other, WER of different settings are: $ beam_size_4 7.08 best for test-other -The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%. +The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%. A few parameters can be tuned to further boost the performance of shallow fusion: -- ``--lm-scale`` +- ``--lm-scale`` - Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large, - the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3. + Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large, + the LM score might be dominant during decoding, leading to bad WER. A typical value of this is around 0.3. + +- ``--beam-size`` -- ``--beam-size`` - The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy. Here, we also show how `--beam-size` effect the WER and decoding time: @@ -176,4 +176,4 @@ As we see, a larger beam size during shallow fusion improves the WER, but is als - + From 027302c902ce9ab44754d42a56cf1eba9a075be9 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 20 Feb 2024 14:38:51 +0800 Subject: [PATCH 102/216] minor fix for param. 
names (#1495) --- icefall/lm_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py index 5e2783a47..26839c61c 100644 --- a/icefall/lm_wrapper.py +++ b/icefall/lm_wrapper.py @@ -159,7 +159,7 @@ class LmScorer(torch.nn.Module): """ if lm_type == "rnn": model = RnnLmModel( - vocab_size=params.vocab_size, + vocab_size=params.lm_vocab_size, embedding_dim=params.rnn_lm_embedding_dim, hidden_dim=params.rnn_lm_hidden_dim, num_layers=params.rnn_lm_num_layers, @@ -183,7 +183,7 @@ class LmScorer(torch.nn.Module): elif lm_type == "transformer": model = TransformerLM( - vocab_size=params.vocab_size, + vocab_size=params.lm_vocab_size, d_model=params.transformer_lm_encoder_dim, embedding_dim=params.transformer_lm_embedding_dim, dim_feedforward=params.transformer_lm_dim_feedforward, From c19b4147789f306efed754dd0ed8f651017a7484 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Wed, 21 Feb 2024 08:04:16 +0800 Subject: [PATCH 103/216] Update docker (adding pypinyin (#1513) Update docker (adding pypinyin) --- .github/scripts/docker/Dockerfile | 1 + .github/scripts/docker/generate_build_matrix.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index ee0099911..4adb7ab5c 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -51,6 +51,7 @@ RUN pip install --no-cache-dir \ onnxruntime \ pytest \ sentencepiece>=0.1.96 \ + pypinyin==0.50.0 \ six \ tensorboard \ typeguard diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 425afac2b..ed01bd740 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20240218" kaldifeat_version = "1.25.4.dev20240218" - version = "1.3" + version = "1.4" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] torch_version += ["2.2.0"] From 13daf73468da70f5db49c928969fce6c6edc041a Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 21 Feb 2024 18:06:27 +0800 Subject: [PATCH 104/216] docs for finetune zipformer (#1509) --- .../from_supervised/finetune_zipformer.rst | 140 ++++++++++++++++++ docs/source/recipes/Finetune/index.rst | 15 ++ docs/source/recipes/index.rst | 1 + 3 files changed, 156 insertions(+) create mode 100644 docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst create mode 100644 docs/source/recipes/Finetune/index.rst diff --git a/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst b/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst new file mode 100644 index 000000000..7ca4eb811 --- /dev/null +++ b/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst @@ -0,0 +1,140 @@ +Finetune from a supervised pre-trained Zipformer model +====================================================== + +This tutorial shows you how to fine-tune a supervised pre-trained **Zipformer** +transducer model on a new dataset. + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. 
HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe + + +For illustration purpose, we fine-tune the Zipformer transducer model +pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your +own data for fine-tuning if you create a manifest for your new dataset. + +Data preparation +---------------- + +Please follow the instructions in the `GigaSpeech recipe `_ +to prepare the fine-tune data used in this tutorial. We only require the small subset in GigaSpeech for this tutorial. + + +Model preparation +----------------- + +We are using the Zipformer model trained on full LibriSpeech (960 hours) as the intialization. The +checkpoint of the model can be downloaded via the following command: + +.. code-block:: bash + + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp + $ git lfs pull --include "pretrained.pt" + $ ln -s pretrained.pt epoch-99.pt + $ cd ../data/lang_bpe_500 + $ git lfs pull --include bpe.model + $ cd ../../.. + +Before fine-tuning, let's test the model's WER on the new domain. The following command performs +decoding on the GigaSpeech test sets: + +.. code-block:: bash + + ./zipformer/decode_gigaspeech.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \ + --use-averaged-model 0 \ + --max-duration 1000 \ + --decoding-method greedy_search + +You should see the following numbers: + +.. code-block:: + + For dev, WER of different settings are: + greedy_search 20.06 best for dev + + For test, WER of different settings are: + greedy_search 19.27 best for test + + +Fine-tune +--------- + +Since LibriSpeech and GigaSpeech are both English dataset, we can initialize the whole +Zipformer model with the checkpoint downloaded in the previous step (otherwise we should consider +initializing the stateless decoder and joiner from scratch due to the mismatch of the output +vocabulary). The following command starts a fine-tuning experiment: + +.. code-block:: bash + + $ use_mux=0 + $ do_finetune=1 + + $ ./zipformer/finetune.py \ + --world-size 2 \ + --num-epochs 20 \ + --start-epoch 1 \ + --exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \ + --use-fp16 1 \ + --base-lr 0.0045 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --do-finetune $do_finetune \ + --use-mux $use_mux \ + --master-port 13024 \ + --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \ + --max-duration 1000 + +The following arguments are related to fine-tuning: + +- ``--base-lr`` + The learning rate used for fine-tuning. We suggest to set a **small** learning rate for fine-tuning, + otherwise the model may forget the initialization very quickly. A reasonable value should be around + 1/10 of the original lr, i.e 0.0045. + +- ``--do-finetune`` + If True, do fine-tuning by initializing the model from a pre-trained checkpoint. + **Note that if you want to resume your fine-tuning experiment from certain epochs, you + need to set this to False.** + +- ``--finetune-ckpt`` + The path to the pre-trained checkpoint (used for initialization). + +- ``--use-mux`` + If True, mix the fine-tune data with the original training data by using `CutSet.mux `_ + This helps maintain the model's performance on the original domain if the original training + is available. **If you don't have the original training data, please set it to False.** + +After fine-tuning, let's test the WERs. 
You can do this via the following command: + +.. code-block:: bash + + $ use_mux=0 + $ do_finetune=1 + $ ./zipformer/decode_gigaspeech.py \ + --epoch 20 \ + --avg 10 \ + --exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \ + --use-averaged-model 1 \ + --max-duration 1000 \ + --decoding-method greedy_search + +You should see numbers similar to the ones below: + +.. code-block:: text + + For dev, WER of different settings are: + greedy_search 13.47 best for dev + + For test, WER of different settings are: + greedy_search 13.66 best for test + +Compared to the original checkpoint, the fine-tuned model achieves much lower WERs +on the GigaSpeech test sets. diff --git a/docs/source/recipes/Finetune/index.rst b/docs/source/recipes/Finetune/index.rst new file mode 100644 index 000000000..e62b8980f --- /dev/null +++ b/docs/source/recipes/Finetune/index.rst @@ -0,0 +1,15 @@ +Fine-tune a pre-trained model +============================= + +After pre-training on public available datasets, the ASR model is already capable of +performing general speech recognition with relatively high accuracy. However, the accuracy +could be still low on certain domains that are quite different from the original training +set. In this case, we can fine-tune the model with a small amount of additional labelled +data to improve the performance on new domains. + + +.. toctree:: + :maxdepth: 2 + :caption: Table of Contents + + from_supervised/finetune_zipformer diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 8df61f0d0..52795d452 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -17,3 +17,4 @@ We may add recipes for other tasks as well in the future. Streaming-ASR/index RNN-LM/index TTS/index + Finetune/index From aac7df064a6d1529f3bf4acccc6c550bd260b7b3 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 22 Feb 2024 15:31:20 +0800 Subject: [PATCH 105/216] Recipes for open vocabulary keyword spotting (#1428) * English recipe on gigaspeech; Chinese recipe on wenetspeech --- ...ev_test.py => compute_fbank_gigaspeech.py} | 14 +- .../local/compute_fbank_gigaspeech_splits.py | 16 +- .../ASR/local/preprocess_gigaspeech.py | 38 +- egs/gigaspeech/ASR/prepare.sh | 61 +- .../ASR/zipformer/asr_datamodule.py | 5 +- egs/gigaspeech/ASR/zipformer/train.py | 21 +- egs/gigaspeech/KWS/RESULTS.md | 49 + egs/gigaspeech/KWS/prepare.sh | 85 + egs/gigaspeech/KWS/run.sh | 197 +++ egs/gigaspeech/KWS/shared | 1 + .../KWS/zipformer/asr_datamodule.py | 477 ++++++ egs/gigaspeech/KWS/zipformer/beam_search.py | 1 + egs/gigaspeech/KWS/zipformer/decode-asr.py | 1066 +++++++++++++ egs/gigaspeech/KWS/zipformer/decode.py | 689 ++++++++ egs/gigaspeech/KWS/zipformer/decoder.py | 1 + .../KWS/zipformer/encoder_interface.py | 1 + .../KWS/zipformer/export-onnx-streaming.py | 1 + egs/gigaspeech/KWS/zipformer/export.py | 1 + egs/gigaspeech/KWS/zipformer/finetune.py | 644 ++++++++ .../KWS/zipformer/gigaspeech_scoring.py | 1 + egs/gigaspeech/KWS/zipformer/joiner.py | 1 + egs/gigaspeech/KWS/zipformer/model.py | 1 + egs/gigaspeech/KWS/zipformer/optim.py | 1 + egs/gigaspeech/KWS/zipformer/scaling.py | 1 + egs/gigaspeech/KWS/zipformer/subsampling.py | 1 + egs/gigaspeech/KWS/zipformer/train.py | 1367 ++++++++++++++++ egs/gigaspeech/KWS/zipformer/zipformer.py | 1 + .../beam_search.py | 241 +++ .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 10 +- .../ASR/pruned_transducer_stateless5/train.py | 8 +- .../local/prepare_dataset_from_kaldi_dir.py | 142 ++ egs/wenetspeech/ASR/local/prepare_pinyin.py | 275 
++++ egs/wenetspeech/ASR/prepare.sh | 16 +- egs/wenetspeech/KWS/RESULTS.md | 58 + egs/wenetspeech/KWS/prepare.sh | 90 ++ egs/wenetspeech/KWS/run.sh | 201 +++ egs/wenetspeech/KWS/shared | 1 + .../KWS/zipformer/asr_datamodule.py | 459 ++++++ egs/wenetspeech/KWS/zipformer/beam_search.py | 1 + egs/wenetspeech/KWS/zipformer/decode-asr.py | 767 +++++++++ egs/wenetspeech/KWS/zipformer/decode.py | 737 +++++++++ egs/wenetspeech/KWS/zipformer/decoder.py | 1 + .../KWS/zipformer/encoder_interface.py | 1 + .../KWS/zipformer/export-onnx-streaming.py | 1 + egs/wenetspeech/KWS/zipformer/export.py | 1 + egs/wenetspeech/KWS/zipformer/finetune.py | 814 ++++++++++ egs/wenetspeech/KWS/zipformer/joiner.py | 1 + egs/wenetspeech/KWS/zipformer/model.py | 1 + egs/wenetspeech/KWS/zipformer/optim.py | 1 + egs/wenetspeech/KWS/zipformer/scaling.py | 1 + .../KWS/zipformer/scaling_converter.py | 1 + egs/wenetspeech/KWS/zipformer/subsampling.py | 1 + egs/wenetspeech/KWS/zipformer/train.py | 1401 +++++++++++++++++ egs/wenetspeech/KWS/zipformer/zipformer.py | 1 + icefall/char_graph_compiler.py | 35 +- icefall/context_graph.py | 218 ++- icefall/utils.py | 96 ++ 57 files changed, 10203 insertions(+), 120 deletions(-) rename egs/gigaspeech/ASR/local/{compute_fbank_gigaspeech_dev_test.py => compute_fbank_gigaspeech.py} (87%) create mode 100644 egs/gigaspeech/KWS/RESULTS.md create mode 100755 egs/gigaspeech/KWS/prepare.sh create mode 100755 egs/gigaspeech/KWS/run.sh create mode 120000 egs/gigaspeech/KWS/shared create mode 100644 egs/gigaspeech/KWS/zipformer/asr_datamodule.py create mode 120000 egs/gigaspeech/KWS/zipformer/beam_search.py create mode 100755 egs/gigaspeech/KWS/zipformer/decode-asr.py create mode 100755 egs/gigaspeech/KWS/zipformer/decode.py create mode 120000 egs/gigaspeech/KWS/zipformer/decoder.py create mode 120000 egs/gigaspeech/KWS/zipformer/encoder_interface.py create mode 120000 egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py create mode 120000 egs/gigaspeech/KWS/zipformer/export.py create mode 100755 egs/gigaspeech/KWS/zipformer/finetune.py create mode 120000 egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py create mode 120000 egs/gigaspeech/KWS/zipformer/joiner.py create mode 120000 egs/gigaspeech/KWS/zipformer/model.py create mode 120000 egs/gigaspeech/KWS/zipformer/optim.py create mode 120000 egs/gigaspeech/KWS/zipformer/scaling.py create mode 120000 egs/gigaspeech/KWS/zipformer/subsampling.py create mode 100755 egs/gigaspeech/KWS/zipformer/train.py create mode 120000 egs/gigaspeech/KWS/zipformer/zipformer.py create mode 100644 egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py create mode 100755 egs/wenetspeech/ASR/local/prepare_pinyin.py create mode 100644 egs/wenetspeech/KWS/RESULTS.md create mode 100755 egs/wenetspeech/KWS/prepare.sh create mode 100755 egs/wenetspeech/KWS/run.sh create mode 120000 egs/wenetspeech/KWS/shared create mode 100644 egs/wenetspeech/KWS/zipformer/asr_datamodule.py create mode 120000 egs/wenetspeech/KWS/zipformer/beam_search.py create mode 100755 egs/wenetspeech/KWS/zipformer/decode-asr.py create mode 100755 egs/wenetspeech/KWS/zipformer/decode.py create mode 120000 egs/wenetspeech/KWS/zipformer/decoder.py create mode 120000 egs/wenetspeech/KWS/zipformer/encoder_interface.py create mode 120000 egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py create mode 120000 egs/wenetspeech/KWS/zipformer/export.py create mode 100755 egs/wenetspeech/KWS/zipformer/finetune.py create mode 120000 egs/wenetspeech/KWS/zipformer/joiner.py create mode 120000 
egs/wenetspeech/KWS/zipformer/model.py create mode 120000 egs/wenetspeech/KWS/zipformer/optim.py create mode 120000 egs/wenetspeech/KWS/zipformer/scaling.py create mode 120000 egs/wenetspeech/KWS/zipformer/scaling_converter.py create mode 120000 egs/wenetspeech/KWS/zipformer/subsampling.py create mode 100755 egs/wenetspeech/KWS/zipformer/train.py create mode 120000 egs/wenetspeech/KWS/zipformer/zipformer.py diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py similarity index 87% rename from egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py rename to egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py index 07beeb1f0..9e0df0989 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py @@ -30,15 +30,15 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_gigaspeech_dev_test(): +def compute_fbank_gigaspeech(): in_out_dir = Path("data/fbank") # number of workers in dataloader num_workers = 20 # number of seconds in a batch - batch_duration = 600 + batch_duration = 1000 - subsets = ("DEV", "TEST") + subsets = ("L", "M", "S", "XS", "DEV", "TEST") device = torch.device("cpu") if torch.cuda.is_available(): @@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test(): logging.info(f"device: {device}") for partition in subsets: - cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz" + cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz" if cuts_path.is_file(): logging.info(f"{cuts_path} exists - skipping") continue - raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz" + raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz" logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) @@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test(): cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, - storage_path=f"{in_out_dir}/feats_{partition}", + storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}", num_workers=num_workers, batch_duration=batch_duration, overwrite=True, @@ -80,7 +80,7 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_gigaspeech_dev_test() + compute_fbank_gigaspeech() if __name__ == "__main__": diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 176eb8a84..51cd59078 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -51,14 +51,6 @@ def get_parser(): "Determines batch size dynamically.", ) - parser.add_argument( - "--subset", - type=str, - default="XL", - choices=["XL", "L", "M", "S", "XS"], - help="Which subset to work with", - ) - parser.add_argument( "--num-splits", type=int, @@ -84,7 +76,7 @@ def get_parser(): def compute_fbank_gigaspeech_splits(args): num_splits = args.num_splits - output_dir = f"data/fbank/{args.subset}_split" + output_dir = f"data/fbank/XL_split" output_dir = Path(output_dir) assert output_dir.exists(), f"{output_dir} does not exist!" 
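For orientation, the per-split feature extraction in this script boils down to a few lhotse calls. The sketch below is an editorial illustration rather than part of the patch: the KaldifeatFbank extractor construction and the example split index are assumptions based on the surrounding icefall scripts, while the paths, num_workers and batch_duration values mirror what the script uses.

    # Minimal sketch of fbank extraction for one split (illustration only).
    import torch
    from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))  # assumed extractor

    idx = "00000000"  # example zero-padded split index
    cut_set = CutSet.from_file(f"data/fbank/XL_split/gigaspeech_cuts_XL_raw.{idx}.jsonl.gz")
    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=f"data/fbank/XL_split/gigaspeech_feats_{idx}",
        num_workers=20,       # as in compute_fbank_gigaspeech.py
        batch_duration=1000,  # seconds of audio per batch, as in compute_fbank_gigaspeech.py
        overwrite=True,
    )
    cut_set.to_file(f"data/fbank/XL_split/gigaspeech_cuts_XL.{idx}.jsonl.gz")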
@@ -107,12 +99,12 @@ def compute_fbank_gigaspeech_splits(args): idx = f"{i}".zfill(num_digits) logging.info(f"Processing {idx}/{num_splits}") - cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz" + cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz" if cuts_path.is_file(): logging.info(f"{cuts_path} exists - skipping") continue - raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz" + raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz" logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) @@ -121,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args): cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, - storage_path=f"{output_dir}/feats_{args.subset}_{idx}", + storage_path=f"{output_dir}/gigaspeech_feats_{idx}", num_workers=args.num_workers, batch_duration=args.batch_duration, overwrite=True, diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 31abe7fff..b6603f80d 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -16,17 +16,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging import re from pathlib import Path from lhotse import CutSet, SupervisionSegment from lhotse.recipes.utils import read_manifests_if_cached +from icefall.utils import str2bool # Similar text filtering and normalization procedure as in: # https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Whether to use speed perturbation.", + ) + + return parser.parse_args() + + def normalize_text( utt: str, punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), @@ -42,7 +56,7 @@ def has_no_oov( return oov_pattern.search(sup.text) is None -def preprocess_giga_speech(): +def preprocess_giga_speech(args): src_dir = Path("data/manifests") output_dir = Path("data/fbank") output_dir.mkdir(exist_ok=True) @@ -51,6 +65,10 @@ def preprocess_giga_speech(): "DEV", "TEST", "XL", + "L", + "M", + "S", + "XS", ) logging.info("Loading manifest (may take 4 minutes)") @@ -71,7 +89,7 @@ def preprocess_giga_speech(): for partition, m in manifests.items(): logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz" + raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz" if raw_cuts_path.is_file(): logging.info(f"{partition} already exists - skipping") continue @@ -94,11 +112,14 @@ def preprocess_giga_speech(): # Run data augmentation that needs to be done in the # time domain. 
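For orientation, the snippet below illustrates the kind of transcript clean-up this file performs; it is an editorial sketch, not part of the patch. The punctuation pattern matches the punct_pattern shown above, while the whitespace collapsing and the garbage-tag set (<SIL>, <NOISE>, <MUSIC>, <OTHER>) are assumptions about helpers whose bodies are not visible in this hunk.

    # Illustration only: strip punctuation tags and collapse whitespace,
    # similar in spirit to normalize_text()/has_no_oov() in this file.
    import re

    punct_pattern = re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>")
    garbage_pattern = re.compile(r"<(SIL|NOISE|MUSIC|OTHER)>")  # assumed tag set
    whitespace_pattern = re.compile(r"\s+")

    def toy_normalize(utt: str) -> str:
        return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)).strip()

    text = "AS THE RAIN STOPPED <COMMA> SHE LEFT <PERIOD>"
    print(toy_normalize(text))           # -> "AS THE RAIN STOPPED SHE LEFT"
    print(garbage_pattern.search(text))  # -> None, so this segment would be kept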
if partition not in ["DEV", "TEST"]: - logging.info( - f"Speed perturb for {partition} with factors 0.9 and 1.1 " - "(Perturbing may take 8 minutes and saving may take 20 minutes)" - ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + if args.perturb_speed: + logging.info( + f"Speed perturb for {partition} with factors 0.9 and 1.1 " + "(Perturbing may take 8 minutes and saving may take 20 minutes)" + ) + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) @@ -107,7 +128,8 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - preprocess_giga_speech() + args = get_args() + preprocess_giga_speech(args) if __name__ == "__main__": diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index a23b708d7..5e54b669a 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then exit 1; fi # Download XL, DEV and TEST sets by default. - lhotse download gigaspeech --subset auto --host tsinghua \ + lhotse download gigaspeech --subset XL \ + --subset L \ + --subset M \ + --subset S \ + --subset XS \ + --subset DEV \ + --subset TEST \ + --host tsinghua \ $dl_dir/password $dl_dir/GigaSpeech fi @@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # We assume that you have downloaded the GigaSpeech corpus # to $dl_dir/GigaSpeech mkdir -p data/manifests - lhotse prepare gigaspeech --subset auto -j $nj \ + lhotse prepare gigaspeech --subset XL \ + --subset L \ + --subset M \ + --subset S \ + --subset XS \ + --subset DEV \ + --subset TEST \ + -j $nj \ $dl_dir/GigaSpeech data/manifests fi @@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)" - python3 ./local/compute_fbank_gigaspeech_dev_test.py + log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech." + python3 ./local/compute_fbank_gigaspeech.py fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then @@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Prepare phone based lang" + log "Stage 9: Prepare transcript_words.txt and words.txt" lang_dir=data/lang_phone mkdir -p $lang_dir - - (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - $dl_dir/lm/lexicon.txt | - sort | uniq > $lang_dir/lexicon.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_dir - fi - if [ ! -f $lang_dir/transcript_words.txt ]; then gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \ | jq '.text' \ @@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then fi if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Prepare BPE based lang" + log "Stage 10: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + + if [ ! 
-f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Prepare BPE based lang" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} @@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then done fi -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Prepare bigram P" +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Prepare bigram P" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} @@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then done fi -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Prepare G" +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Prepare G" # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm @@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then fi fi -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Compile HLG" +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Compile HLG" ./local/compile_hlg.py --lang-dir data/lang_phone for vocab_size in ${vocab_sizes[@]}; do diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 850ab7c10..0501461cd 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule: group.add_argument( "--num-buckets", type=int, - default=30, + default=100, help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -368,6 +368,8 @@ class GigaSpeechAsrDataModule: valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, shuffle=False, ) logging.info("About to create dev dataloader") @@ -417,6 +419,7 @@ class GigaSpeechAsrDataModule: logging.info( f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" ) + cuts_train = lhotse.combine( lhotse.load_manifest_lazy(p) for p in sorted_filenames ) diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index d93cc221c..c5335562c 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -416,6 +416,17 @@ def get_parser(): help="Accumulate stats on activations, print them and exit.", ) + parser.add_argument( + "--scan-for-oom-batches", + type=str2bool, + default=False, + help=""" + Whether to scan for oom batches before training, this is helpful for + finding the suitable max_duration, you only need to run it once. + Caution: a little time consuming. 
+ """, + ) + parser.add_argument( "--inf-check", type=str2bool, @@ -1171,9 +1182,16 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) + def remove_short_utt(c: Cut): + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + return T > 0 + gigaspeech = GigaSpeechAsrDataModule(args) train_cuts = gigaspeech.train_cuts() + train_cuts = train_cuts.filter(remove_short_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1187,9 +1205,10 @@ def run(rank, world_size, args): ) valid_cuts = gigaspeech.dev_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if not params.print_diagnostics and params.scan_for_oom_batches: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, diff --git a/egs/gigaspeech/KWS/RESULTS.md b/egs/gigaspeech/KWS/RESULTS.md new file mode 100644 index 000000000..992240e14 --- /dev/null +++ b/egs/gigaspeech/KWS/RESULTS.md @@ -0,0 +1,49 @@ +# Results + +## zipformer transducer model + +This is a tiny general ASR model, which has around 3.3M parameters, see this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details. + +The modeling units are 500 BPEs trained on gigaspeech transcripts. + +The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is test set of gigaspeech (has 40 hours audios). + +We put the whole pipeline in `run.sh` containing training, decoding and finetuning commands. + +The models have been upload to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-gigaspeech-20240219.tar.gz). + +Here is the results of a small test set which has 20 commands, we list the results of every commands, for +each metric there are two columns, one for the original model trained on gigaspeech XL subset, the other +for the finetune model finetuned on commands dataset. 
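To make the tables below easier to read: Recall appears to be derived from the false negatives over the positive set, and the false-alarm rate divides the false positives by the duration of the negative set (40 hours here). A quick sketch of that arithmetic for the "All" row of the small test set:

```python
# Sketch: how the Recall and "False alarm (time / hour)" columns are derived,
# using the "All" row of the small test set below.
positives = 307        # keyword occurrences in the positive set
negative_hours = 40.0  # duration of the negative set (GigaSpeech test)

for name, fn, fp in [("original", 43, 1), ("finetuned", 4, 24)]:
    recall = (positives - fn) / positives
    fa_per_hour = fp / negative_hours
    print(f"{name}: recall = {recall:.1%}, false alarms/hour = {fa_per_hour:.3f}")
# original:  recall = 86.0%, false alarms/hour = 0.025
# finetuned: recall = 98.7%, false alarms/hour = 0.600
```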
+ +Commands | FN in positive set |FN in positive set | Recall | Recall | FP in negative set | FP in negative set| False alarm (time / hour) 40 hours | False alarm (time / hour) 40 hours | +-- | -- | -- | -- | --| -- | -- | -- | -- +  | original | finetune | original | finetune | original | finetune | original | finetune +All | 43/307 | 4/307 | 86% | 98.7% | 1 | 24 | 0.025 | 0.6 +Lights on | 6/17 | 0/17 | 64.7% | 100% | 1 | 9 | 0.025 | 0.225 +Heat up | 5/14 | 1/14 | 64.3% | 92.9% | 0 | 1 | 0 | 0.025 +Volume down | 4/18 | 0/18 | 77.8% | 100% | 0 | 2 | 0 | 0.05 +Volume max | 4/17 | 0/17 | 76.5% | 100% | 0 | 0 | 0 | 0 +Volume mute | 4/16 | 0/16 | 75.0% | 100% | 0 | 0 | 0 | 0 +Too quiet | 3/17 | 0/17 | 82.4% | 100% | 0 | 4 | 0 | 0.1 +Lights off | 3/17 | 0/17 | 82.4% | 100% | 0 | 2 | 0 | 0.05 +Play music | 2/14 | 0/14 | 85.7% | 100% | 0 | 0 | 0 | 0 +Bring newspaper | 2/13 | 1/13 | 84.6% | 92.3% | 0 | 0 | 0 | 0 +Heat down | 2/16 | 2/16 | 87.5% | 87.5% | 0 | 1 | 0 | 0.025 +Volume up | 2/18 | 0/18 | 88.9% | 100% | 0 | 1 | 0 | 0.025 +Too loud | 1/13 | 0/13 | 92.3% | 100% | 0 | 0 | 0 | 0 +Resume music | 1/14 | 0/14 | 92.9% | 100% | 0 | 0 | 0 | 0 +Bring shoes | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 +Switch language | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 +Pause music | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 +Bring socks | 1/12 | 0/12 | 91.7% | 100% | 0 | 0 | 0 | 0 +Stop music | 0/15 | 0/15 | 100% | 100% | 0 | 0 | 0 | 0 +Turn it up | 0/15 | 0/15 | 100% | 100% | 0 | 3 | 0 | 0.075 +Turn it down | 0/16 | 0/16 | 100% | 100% | 0 | 1 | 0 | 0.025 + +This is the result of large test set, it has more than 200 commands, too many to list the details of each commands, so only an overall result here. + +Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours +-- | -- | -- | -- | -- | -- | -- | -- | -- +  | original | finetune | original | finetune | original | finetune | original | finetune +All | 622/3994 | 79/ 3994 | 83.6% | 97.9% | 18/19930 | 52/19930 | 0.45 | 1.3 diff --git a/egs/gigaspeech/KWS/prepare.sh b/egs/gigaspeech/KWS/prepare.sh new file mode 100755 index 000000000..0b098190d --- /dev/null +++ b/egs/gigaspeech/KWS/prepare.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare gigaspeech dataset." + mkdir -p data/fbank + if [ ! -e data/fbank/.gigaspeech.done ]; then + pushd ../ASR + ./prepare.sh --stage 0 --stop-stage 9 + ./prepare.sh --stage 11 --stop-stage 11 + popd + pushd data/fbank + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) . 
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) . + ln -svf $(realpath ../ASR/data/fbank/XL_split) . + ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/musan_feats) . + popd + pushd data + ln -svf $(realpath ../ASR/data/lang_bpe_500) . + popd + touch data/fbank/.gigaspeech.done + else + log "Gigaspeech dataset already exists, skipping." + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare open commands dataset." + mkdir -p data/fbank + if [ ! -e data/fbank/.fluent_speech_commands.done ]; then + pushd data + git clone https://github.com/pkufool/open-commands.git + ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt + ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt + pushd open-commands + ./script/prepare.sh --stage 2 --stop-stage 2 + ./script/prepare.sh --stage 6 --stop-stage 6 + popd + popd + pushd data/fbank + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) . + popd + touch data/fbank/.fluent_speech_commands.done + else + log "Fluent speech commands dataset already exists, skipping." + fi +fi diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh new file mode 100755 index 000000000..ea04c7c9b --- /dev/null +++ b/egs/gigaspeech/KWS/run.sh @@ -0,0 +1,197 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +export CUDA_VISIBLE_DEVICES="0,1,2,3" +export PYTHONPATH=../../../:$PYTHONPATH + +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Train a model." + if [ ! -e data/fbank/.gigaspeech.done ]; then + log "You need to run the prepare.sh first." 
+ exit -1 + fi + + python ./zipformer/train.py \ + --world-size 4 \ + --exp-dir zipformer/exp \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --num-epochs 12 \ + --lr-epochs 1.5 \ + --use-fp16 1 \ + --start-epoch 1 \ + --subset XL \ + --bpe-model data/lang_bpe_500/bpe.model \ + --causal 1 \ + --max-duration 1000 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Decode the model." + for t in small, large; do + python ./zipformer/decode.py \ + --epoch 12 \ + --avg 2 \ + --exp-dir ./zipformer/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --test-set $t \ + --keywords-score 1.0 \ + --keywords-threshold 0.35 \ + --keywords-file ./data/commands_${t}.txt \ + --max-duration 3000 + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Export the model." + + python ./zipformer/export.py \ + --epoch 12 \ + --avg 2 \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 + + python ./zipformer/export_onnx_streaming.py \ + --exp-dir zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 12 \ + --avg 2 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 2: Finetune the model" + + # The following configuration of lr schedule should work well + # You may also tune the following parameters to adjust learning rate schedule + base_lr=0.0005 + lr_epochs=100 + lr_batches=100000 + + # We recommend to start from an averaged model + finetune_ckpt=zipformer/exp/pretrained.pt + + ./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --exp-dir zipformer/exp_finetune \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-fp16 1 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 \ + --base-lr $base_lr \ + --lr-epochs $lr_epochs \ + --lr-batches $lr_batches \ + --finetune-ckpt $finetune_ckpt \ + --max-duration 1500 +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 1: Decode the finetuned model." 
+ for t in small, large; do + python ./zipformer/decode.py \ + --epoch 10 \ + --avg 2 \ + --exp-dir ./zipformer/exp_finetune \ + --bpe-model data/lang_bpe_500/bpe.model \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --test-set $t \ + --keywords-score 1.0 \ + --keywords-threshold 0.35 \ + --keywords-file ./data/commands_${t}.txt \ + --max-duration 3000 + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 2: Export the finetuned model." + + python ./zipformer/export.py \ + --epoch 10 \ + --avg 2 \ + --exp-dir ./zipformer/exp_finetune \ + --tokens data/lang_bpe_500/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 + + python ./zipformer/export_onnx_streaming.py \ + --exp-dir zipformer/exp_finetune \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 10 \ + --avg 2 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 +fi diff --git a/egs/gigaspeech/KWS/shared b/egs/gigaspeech/KWS/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/gigaspeech/KWS/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py new file mode 100644 index 000000000..ccc602404 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py @@ -0,0 +1,477 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
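Before the datamodule implementation that follows, a brief sketch of how it is typically consumed. This mirrors how the train and decode scripts in this patch use it; the argument values passed to `parse_args` are examples only.

```python
# Sketch: wiring GigaSpeechAsrDataModule into a script via argparse.
# Mirrors the usage in train.py / decode.py; flag values are examples only.
import argparse

from asr_datamodule import GigaSpeechAsrDataModule

parser = argparse.ArgumentParser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--subset", "XL", "--max-duration", "1000"])

gigaspeech = GigaSpeechAsrDataModule(args)

train_cuts = gigaspeech.train_cuts()  # for XL this lazily combines the splits
train_dl = gigaspeech.train_dataloaders(train_cuts)

valid_dl = gigaspeech.valid_dataloaders(gigaspeech.dev_cuts())
test_dl = gigaspeech.test_dataloaders(gigaspeech.test_cuts())
```

All of these loaders respect the sampler options registered by `add_arguments` (e.g. `--num-buckets`, `--max-duration`, `--on-the-fly-feats`), so scripts only need to forward the parsed args.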
+ + +import argparse +import glob +import inspect +import logging +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import lhotse +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class GigaSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # GigaSpeech specific arguments + group.add_argument( + "--subset", + type=str, + default="XL", + help="Select the GigaSpeech subset (XS|S|M|L|XL)", + ) + group.add_argument( + "--small-dev", + type=str2bool, + default=False, + help="Should we use only 1000 utterances for dev (speeds up training)", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. 
+ num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get train {self.args.subset} cuts") + if self.args.subset == "XL": + filenames = glob.glob( + f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" + ) + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + sorted_filenames = [f[1] for f in idx_filenames] + logging.info( + f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" + ) + + cuts_train = lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) + else: + path = ( + self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" + ) + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + ) + if self.args.small_dev: + return cuts_valid.subset(first=1000) + else: + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + ) + + @lru_cache() + def fsc_train_cuts(self) -> CutSet: + logging.info("About to get fluent speech commands train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz" + ) + + @lru_cache() + def fsc_valid_cuts(self) -> CutSet: + logging.info("About to get fluent speech commands valid 
cuts") + return load_manifest_lazy( + self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def fsc_test_small_cuts(self) -> CutSet: + logging.info("About to get fluent speech commands small test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz" + ) + + @lru_cache() + def fsc_test_large_cuts(self) -> CutSet: + logging.info("About to get fluent speech commands large test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz" + ) diff --git a/egs/gigaspeech/KWS/zipformer/beam_search.py b/egs/gigaspeech/KWS/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/decode-asr.py b/egs/gigaspeech/KWS/zipformer/decode-asr.py new file mode 100755 index 000000000..149b8bed0 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/decode-asr.py @@ -0,0 +1,1066 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + 
hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. 
+ """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), 
"chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over 
iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + test_cuts = gigaspeech.test_cuts() + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_fsc_cuts = gigaspeech.fsc_test_large_cuts() + test_fsc_dl = gigaspeech.test_dataloaders(test_fsc_cuts) + + test_sets = ["test", "fsc_test"] + test_dls = [test_dl, test_fsc_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py new file mode 100755 index 000000000..98b003937 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/decode.py @@ -0,0 +1,689 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --keywords-file keywords.txt \ + --beam-size 4 +""" + +import argparse +import logging +import math +import os +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from beam_search import ( + keywords_search, +) +from train import add_model_arguments, get_model, get_params + +from lhotse.cut import Cut +from icefall import ContextGraph +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +@dataclass +class KwMetric: + TP: int = 0 # True positive + FN: int = 0 # False negative + FP: int = 0 # False positive + TN: int = 0 # True negative + FN_list: List[str] = field(default_factory=list) + FP_list: List[str] = field(default_factory=list) + TP_list: List[str] = field(default_factory=list) + + def __str__(self) -> str: + return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})" + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--beam", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--keywords-file", + type=str, + help="File contains keywords.", + ) + + parser.add_argument( + "--test-set", + type=str, + default="small", + help="small or large", + ) + + parser.add_argument( + "--keywords-score", + type=float, + default=1.5, + help=""" + The default boosting score (token level) for keywords. it will boost the + paths that match keywords to make them survive beam search. + """, + ) + + parser.add_argument( + "--keywords-threshold", + type=float, + default=0.35, + help="The default threshold (probability) to trigger the keyword.", + ) + + parser.add_argument( + "--num-tailing-blanks", + type=int, + default=1, + help="The number of tailing blanks should have after hitting one keyword.", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + keywords_graph: Optional[ContextGraph] = None, +) -> List[List[Tuple[str, Tuple[int, int]]]]: + """Decode one batch and return the result in a list. + + The length of the list equals to batch size, the i-th element contains the + triggered keywords for the i-th utterance in the given batch. The triggered + keywords are also a list, each of it contains a tuple of hitting keyword and + the corresponding start timestamps and end timestamps of the hitting keyword. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. 
+ keywords_graph: + The graph containing keywords. + Returns: + Return the decoding result. See above description for the format of + the returned list. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + ans_dict = keywords_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + keywords_graph=keywords_graph, + beam=params.beam, + num_tailing_blanks=params.num_tailing_blanks, + blank_penalty=params.blank_penalty, + ) + + hyps = [] + for ans in ans_dict: + hyp = [] + for hit in ans: + hyp.append((hit.phrase, (hit.timestamps[0], hit.timestamps[-1]))) + hyps.append(hyp) + + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + keywords_graph: ContextGraph, + keywords: Set[str], + test_only_keywords: bool, +) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + keywords_graph: + The graph containing keywords. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 50 + + results = [] + metric = {"all": KwMetric()} + for k in keywords: + metric[k] = KwMetric() + + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps = decode_one_batch( + params=params, + model=model, + sp=sp, + keywords_graph=keywords_graph, + batch=batch, + ) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_text = ref_text.upper() + ref_words = ref_text.split() + hyp_words = [x[0] for x in hyp_words] + # for computing WER + this_batch.append((cut_id, ref_words, " ".join(hyp_words).split())) + hyp_set = set(hyp_words) # each item is a keyword phrase + if len(hyp_words) > 1: + logging.warning( + f"Cut {cut_id} triggers more than one keywords : {hyp_words}," + f"please check the transcript to see if it really has more " + f"than one keywords, if so consider splitting this audio and" + f"keep only one keyword for each audio." + ) + hyp_str = " | ".join( + hyp_words + ) # The triggered keywords for this utterance. 
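+            # Summary of the per-keyword bookkeeping below:
+            #   TP: the keyword was triggered and matches the reference text;
+            #   FP: the keyword was triggered but does not match the reference;
+            #   TN: the keyword neither occurs in the reference nor was triggered;
+            #   FN: the keyword occurs in the reference, but no triggered keyword matches it.
+            # The "all" entry counts at most one event per category per utterance.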
+ TP = False + FP = False + for x in hyp_set: + assert x in keywords, x # can only trigger keywords + if (test_only_keywords and x == ref_text) or ( + not test_only_keywords and x in ref_text + ): + TP = True + metric[x].TP += 1 + metric[x].TP_list.append(f"({ref_text} -> {x})") + if (test_only_keywords and x != ref_text) or ( + not test_only_keywords and x not in ref_text + ): + FP = True + metric[x].FP += 1 + metric[x].FP_list.append(f"({ref_text} -> {x})") + if TP: + metric["all"].TP += 1 + if FP: + metric["all"].FP += 1 + TN = True # all keywords are true negative then the summery is true negative. + FN = False + for x in keywords: + if x not in ref_text and x not in hyp_set: + metric[x].TN += 1 + continue + + TN = False + if (test_only_keywords and x == ref_text) or ( + not test_only_keywords and x in ref_text + ): + fn = True + for y in hyp_set: + if (test_only_keywords and y == ref_text) or ( + not test_only_keywords and y in ref_text + ): + fn = False + break + if fn: + FN = True + metric[x].FN += 1 + metric[x].FN_list.append(f"({ref_text} -> {hyp_str})") + if TN: + metric["all"].TN += 1 + if FN: + metric["all"].FN += 1 + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results, metric + + +def save_results( + params: AttributeDict, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], + metric: KwMetric, +): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt" + + with open(metric_filename, "w") as of: + width = 10 + for key, item in sorted( + metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True + ): + acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN) + precision = ( + 0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP) + ) + recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN) + fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN) + s = f"{key}:\n" + s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n" + s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n" + s += f"\tAccuracy: {acc:.3f}\n" + s += f"\tPrecision: {precision:.3f}\n" + s += f"\tRecall(PPR): {recall:.3f}\n" + s += f"\tFPR: {fpr:.3f}\n" + s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n" + if key != "all": + s += f"\tTP list: {' # '.join(item.TP_list)}\n" + s += f"\tFP list: {' # '.join(item.FP_list)}\n" + s += f"\tFN list: {' # '.join(item.FN_list)}\n" + of.write(s + "\n") + if key == "all": + logging.info(s) + of.write(f"\n\n{params.keywords_config}") + + logging.info("Wrote metric stats to {}".format(metric_filename)) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "kws" + + params.suffix = params.test_set + if params.iter > 0: + params.suffix += f"-iter-{params.iter}-avg-{params.avg}" + else: + params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + params.suffix += f"-score-{params.keywords_score}" + params.suffix += f"-threshold-{params.keywords_threshold}" + params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" + if params.blank_penalty != 0: + params.suffix += f"-blank-penalty-{params.blank_penalty}" + params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + phrases = [] + token_ids = [] + keywords_scores = [] + keywords_thresholds = [] + keywords_config = [] + with open(params.keywords_file, "r") as f: + for line in f.readlines(): + keywords_config.append(line) + score = 0 + threshold = 0 + keyword = [] + words = line.strip().upper().split() + for word in words: + word = word.strip() + if word[0] == ":": + score = float(word[1:]) + continue + if word[0] == "#": + threshold = float(word[1:]) + continue + keyword.append(word) + keyword = " ".join(keyword) + phrases.append(keyword) + token_ids.append(sp.encode(keyword)) + keywords_scores.append(score) + keywords_thresholds.append(threshold) + + params.keywords_config = "".join(keywords_config) + + keywords_graph = ContextGraph( + context_score=params.keywords_score, ac_threshold=params.keywords_threshold + ) + keywords_graph.build( + token_ids=token_ids, + phrases=phrases, + scores=keywords_scores, + ac_thresholds=keywords_thresholds, + ) + keywords = set(phrases) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from 
{filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + test_cuts = gigaspeech.test_cuts() + test_dl = gigaspeech.test_dataloaders(test_cuts) + + if params.test_set == "small": + test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts() + test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts) + test_sets = ["small-fsc", "test"] + test_dls = [test_fsc_small_dl, test_dl] + else: + assert params.test_set == "large", params.test_set + test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts() + test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts) + test_sets = ["large-fsc", "test"] + test_dls = [test_fsc_large_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results, metric = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + keywords_graph=keywords_graph, + keywords=keywords, + test_only_keywords="fsc" in test_set, + ) + + save_results( + params=params, + test_set_name=test_set, + results=results, + metric=metric, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/KWS/zipformer/decoder.py b/egs/gigaspeech/KWS/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/encoder_interface.py b/egs/gigaspeech/KWS/zipformer/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/export.py b/egs/gigaspeech/KWS/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py new file mode 100755 index 000000000..b8e8802cb --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -0,0 +1,644 @@ +#!/usr/bin/env python3 
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For non-streaming model training: +./zipformer/finetune.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/fintune.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +from train import ( + add_model_arguments, + add_training_arguments, + compute_loss, + compute_validation_loss, + display_and_save_batch, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + scan_pessimistic_batches_for_oom, + set_batch_count, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. 
+ """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + parser.add_argument( + "--continue-finetune", + type=str2bool, + default=False, + help="Continue finetuning or finetune from pre-trained model", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + add_training_arguments(parser) + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params) + 100000) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + + # if params.continue_finetune: + # set_batch_count(model, params.batch_idx_train) + # else: + # set_batch_count(model, params.batch_idx_train + 100000) + + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
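+            # The heuristic below doubles the grad scale whenever it is below 8.0
+            # (or below 32.0, checked only every 400 batches), saves a "bad-model"
+            # checkpoint and warns once it drops below 0.01, and aborts if it
+            # collapses under 1e-5.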
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + + if params.continue_finetune: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_utt(c: Cut): + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + return T > 0 + + gigaspeech = GigaSpeechAsrDataModule(args) + + if params.use_mux: + train_cuts = CutSet.mux( + gigaspeech.train_cuts(), + gigaspeech.fsc_train_cuts(), + weights=[0.9, 0.1], + ) + else: + train_cuts = gigaspeech.fsc_train_cuts() + + train_cuts = train_cuts.filter(remove_short_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.fsc_valid_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + 
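+    # The validation cuts go through the same short-utterance filter so that the
+    # encoder's subsampling never produces an empty (T <= 0) output.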
valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py new file mode 120000 index 000000000..4ee54fff5 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py @@ -0,0 +1 @@ +../../ASR/zipformer/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/joiner.py b/egs/gigaspeech/KWS/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/model.py b/egs/gigaspeech/KWS/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/optim.py b/egs/gigaspeech/KWS/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/scaling.py b/egs/gigaspeech/KWS/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/subsampling.py b/egs/gigaspeech/KWS/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/subsampling.py @@ -0,0 +1 @@ 
+../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py new file mode 100755 index 000000000..e7387dd39 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -0,0 +1,1367 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
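+    # For example, with --max-duration 1000, --world-size 4 and the default
+    # --ref-duration 600, each real batch counts as roughly 6.7 reference batches.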
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="1,1,1,1,1,1", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="192,192,192,192,192,192", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="128,128,128,128,128,128", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="128,128,128,128,128,128", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=320, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=320, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=True, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + add_training_arguments(parser) + add_model_arguments(parser) + + return parser + + +def add_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=1, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--scan-for-oom-batches", + type=str2bool, + default=False, + help=""" + Whether to scan for oom batches before training, this is helpful for + finding the suitable max_duration, you only need to run it once. + Caution: a little time consuming. + """, + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 500, + "reset_interval": 2000, + "valid_interval": 20000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = 
get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
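+
+    Example (an illustrative call; it mirrors how ``run()`` invokes this
+    function at the end of an epoch, assuming ``params.cur_epoch == 3``)::
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+        # Writes exp-dir/epoch-3.pt; if this epoch has the best train or
+        # valid loss so far, the file is also copied to best-train-loss.pt
+        # or best-valid-loss.pt respectively.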
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. 
+ world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_utt(c: Cut): + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + return T > 0 + + gigaspeech = GigaSpeechAsrDataModule(args) + + train_cuts = gigaspeech.train_cuts() + train_cuts = train_cuts.filter(remove_short_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.dev_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and 
"grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/KWS/zipformer/zipformer.py b/egs/gigaspeech/KWS/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7fcd242fc..66c84b2a9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import warnings from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Union @@ -31,6 +32,7 @@ from icefall.rnn_lm.model import RnnLmModel from icefall.transformer_lm.model import TransformerLM from icefall.utils import ( DecodingResults, + KeywordResult, add_eos, add_sos, get_texts, @@ -789,6 +791,8 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor + ac_probs: Optional[List[float]] = None + # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded timestamp: List[int] = field(default_factory=list) @@ -805,6 +809,8 @@ class Hypothesis: # Context graph state context_state: Optional[ContextState] = None + num_tailing_blanks: int = 0 + @property def key(self) -> str: """Return a string representation of self.ys""" @@ -953,6 +959,241 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: return ans +def keywords_search( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + keywords_graph: ContextGraph, + beam: int = 4, + num_tailing_blanks: int = 0, + blank_penalty: float = 0, +) -> List[List[KeywordResult]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + keywords_graph: + A instance of ContextGraph containing keywords and their configurations. + beam: + Number of active paths during the beam search. + num_tailing_blanks: + The number of tailing blanks a keyword should be followed, this is for the + scenario that a keyword will be the prefix of another. In most cases, you + can just set it to 0. + blank_penalty: + The score used to penalize blank probability. + Returns: + Return a list of list of KeywordResult. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert keywords_graph is not None + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=keywords_graph.root, + timestamp=[], + ac_probs=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + sorted_ans = [[] for _ in range(N)] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + probs = logits.softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs = probs.log() + + probs = probs.reshape(-1) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_probs = k2.RaggedTensor(shape=log_probs_shape, value=probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + hyp_probs = ragged_probs[i].tolist() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + new_ac_probs = hyp.ac_probs[:] + context_score = 0 + new_context_state = hyp.context_state + new_num_tailing_blanks = hyp.num_tailing_blanks + 1 + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + new_ac_probs.append(hyp_probs[topk_indexes[k]]) + ( + context_score, + new_context_state, + _, + ) = keywords_graph.forward_one_step(hyp.context_state, new_token) + new_num_tailing_blanks = 0 + if new_context_state.token == -1: # root + new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id] + + new_log_prob = topk_log_probs[k] + context_score + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ac_probs=new_ac_probs, + context_state=new_context_state, + num_tailing_blanks=new_num_tailing_blanks, + ) + B[i].add(new_hyp) + + top_hyp = B[i].get_most_probable(length_norm=True) + matched, matched_state = keywords_graph.is_matched(top_hyp.context_state) + if matched: + ac_prob = ( + sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level + ) + if ( + matched + and top_hyp.num_tailing_blanks > num_tailing_blanks + and ac_prob >= matched_state.ac_threshold + ): + keyword = KeywordResult( + hyps=top_hyp.ys[-matched_state.level :], + timestamps=top_hyp.timestamp[-matched_state.level :], + phrase=matched_state.phrase, + ) + sorted_ans[i].append(keyword) + B[i] = HypothesisList() + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=keywords_graph.root, + timestamp=[], + ac_probs=[], + ) + ) + + B = B + finalized_B + + for i, hyps in enumerate(B): + top_hyp = hyps.get_most_probable(length_norm=True) + matched, matched_state = keywords_graph.is_matched(top_hyp.context_state) + if matched: + ac_prob = ( + sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level + ) + if matched and ac_prob >= matched_state.ac_threshold: + keyword = KeywordResult( + hyps=top_hyp.ys[-matched_state.level :], + timestamps=top_hyp.timestamp[-matched_state.level :], + 
phrase=matched_state.phrase, + ) + sorted_ans[i].append(keyword) + + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + return ans + + def modified_beam_search( model: nn.Module, encoder_out: torch.Tensor, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 814390ad6..2ab051e83 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -479,14 +479,18 @@ class LibriSpeechAsrDataModule: @lru_cache() def gigaspeech_subset_small_cuts(self) -> CutSet: logging.info("About to get Gigaspeech subset-S cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz") @lru_cache() def gigaspeech_dev_cuts(self) -> CutSet: logging.info("About to get Gigaspeech dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + ) @lru_cache() def gigaspeech_test_cuts(self) -> CutSet: logging.info("About to get Gigaspeech test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + ) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index d03970265..c0aedd725 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -602,11 +602,9 @@ def compute_loss( feature_lens = supervisions["num_frames"].to(device) texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids_with_bpe(texts) - if type(y) == list: - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) + y = graph_compiler.texts_to_ids(texts, sep="/") + y = k2.RaggedTensor(y).to(device) + with torch.set_grad_enabled(is_training): simple_loss, pruned_loss = model( x=feature, diff --git a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py new file mode 100644 index 000000000..334a6d023 --- /dev/null +++ b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging + +import torch +import lhotse +from pathlib import Path +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + fix_manifests, + validate_recordings_and_supervisions, +) +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. 
+# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--kaldi-dir", + type=str, + help="""The directory containing kaldi style manifest, namely wav.scp, text and segments. + """, + ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bank bins. + """, + ) + + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="""The directory where the lhotse manifests and features to write to. + """, + ) + + parser.add_argument( + "--dataset", + type=str, + help="""The name of dataset. + """, + ) + + parser.add_argument( + "--partition", + type=str, + help="""Could be something like train, valid, test and so on. + """, + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + parser.add_argument( + "--num-jobs", type=int, default=50, help="The num of jobs to extract feature." + ) + + return parser.parse_args() + + +def prepare_cuts(args): + logging.info(f"Prepare cuts from {args.kaldi_dir}.") + recordings, supervisions, _ = lhotse.load_kaldi_data_dir(args.kaldi_dir, 16000) + recordings, supervisions = fix_manifests(recordings, supervisions) + validate_recordings_and_supervisions(recordings, supervisions) + cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions) + return cuts + + +def compute_feature(args, cuts): + extractor = Fbank(FbankConfig(num_mel_bins=args.num_mel_bins)) + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{args.dataset}_cuts_{args.partition}.jsonl.gz" + if (args.output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {cuts_filename}") + + if "train" in args.partition: + if args.perturb_speed: + logging.info(f"Doing speed perturb") + cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) + cuts = cuts.compute_and_store_features( + extractor=extractor, + storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}", + # when an executor is specified, make more partitions + num_jobs=args.num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cuts.to_file(args.output_dir / cuts_filename) + + +def main(args): + args.kaldi_dir = Path(args.kaldi_dir) + args.output_dir = Path(args.output_dir) + cuts = prepare_cuts(args) + compute_feature(args, cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + main(args) diff --git a/egs/wenetspeech/ASR/local/prepare_pinyin.py b/egs/wenetspeech/ASR/local/prepare_pinyin.py new file mode 100755 index 000000000..ae40f1cdd --- /dev/null +++ b/egs/wenetspeech/ASR/local/prepare_pinyin.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes as input `lang_dir`, which should contain:: + - lang_dir/words.txt +and generates the following files in the directory `lang_dir`: + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" +import argparse +import re +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) +from icefall.utils import text_to_pinyin + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Prepare lang for pinyin", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("--lang-dir", type=str, help="The lang directory.") + + parser.add_argument( + "--token-type", + default="full_with_tone", + type=str, + help="""The type of pinyin, should be in: + full_with_tone: zhōng guó + full_no_tone: zhong guo + partial_with_tone: zh ōng g uó + partial_no_tone: zh ong g uo + """, + ) + + parser.add_argument( + "--pinyin-errors", + default="split", + type=str, + help="""How to handle characters that has no pinyin, + see `text_to_pinyin` in icefall/utils.py for details + """, + ) + + return parser + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. 
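+
+    Example (illustrative): for an entry whose word id is ``w`` and whose
+    pronunciation maps to the two token ids ``p1, p2``, the arcs added are
+    roughly::
+
+        loop_state --(p1 : w)-->   state_1
+        state_1    --(p2 : eps)--> loop_state
+
+    i.e. the word label is emitted on the first arc, epsilon on the rest,
+    and every pronunciation returns to the shared loop state.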
+ """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: + """Check if all the given tokens are in token symbol table. + Args: + token_sym_table: + Token symbol table that contains all the valid tokens. + tokens: + A list of tokens. + Returns: + Return True if there is any token not in the token_sym_table, + otherwise False. + """ + for tok in tokens: + if tok not in token_sym_table: + return True + return False + + +def generate_lexicon( + args, token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: + """Generate a lexicon from a word list and token_sym_table. + Args: + token_sym_table: + Token symbol table that mapping token to token ids. + words: + A list of strings representing words. + Returns: + Return a dict whose keys are words and values are the corresponding + tokens. + """ + lexicon = [] + for word in words: + tokens = text_to_pinyin( + word.strip(), mode=args.token_type, errors=args.pinyin_errors + ) + if contain_oov(token_sym_table, tokens): + print(f"Word : {word} contains OOV token, skipping.") + continue + lexicon.append((word, tokens)) + + # The OOV word is + lexicon.append(("", [""])) + return lexicon + + +def generate_tokens(args, words: List[str]) -> Dict[str, int]: + """Generate tokens from the given word list. + Args: + words: + A list that contains words to generate tokens. + Returns: + Return a dict whose keys are tokens and values are token ids ranged + from 0 to len(keys) - 1. 
+ """ + tokens: Dict[str, int] = dict() + tokens[""] = 0 + tokens[""] = 1 + tokens[""] = 2 + for word in words: + word = word.strip() + tokens_list = text_to_pinyin( + word, mode=args.token_type, errors=args.pinyin_errors + ) + for token in tokens_list: + if token not in tokens: + tokens[token] = len(tokens) + return tokens + + +def main(): + parser = get_parser() + args = parser.parse_args() + + lang_dir = Path(args.lang_dir) + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", "", "#0", "", ""] + for w in excluded: + if w in words: + words.remove(w) + + token_sym_table = generate_tokens(args, words) + + lexicon = generate_lexicon(args, token_sym_table, words) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index b0525de60..543d19ce0 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -364,4 +364,18 @@ if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ --vocab-size 5537 \ --master-port 12340 -fi \ No newline at end of file +fi + +if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then + log "Stage 22: Prepare pinyin based lang" + for token in full_with_tone partial_with_tone; do + lang_dir=data/lang_${token} + if [ ! -f $lang_dir/tokens.txt ]; then + cp data/lang_char/words.txt $lang_dir/words.txt + python local/prepare_pinyin.py \ + --token-type $token \ + --lang-dir $lang_dir + fi + python ./local/compile_lg.py --lang-dir $lang_dir + done +fi diff --git a/egs/wenetspeech/KWS/RESULTS.md b/egs/wenetspeech/KWS/RESULTS.md new file mode 100644 index 000000000..29da3e2e5 --- /dev/null +++ b/egs/wenetspeech/KWS/RESULTS.md @@ -0,0 +1,58 @@ +# Results + +## zipformer transducer model + +This is a tiny general ASR model, which has around 3.3M parameters, see this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details. + +The modeling units are partial pinyin (i.e initials and finals) with tone. + +The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is test net of wenetspeech (has 23 hours audios). + +We put the whole pipeline in `run.sh` containing training, decoding and finetuning commands. + +The models have been upload to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-wenetspeech-20240219.tar.gz). 
+ +Here is the results of a small test set which has 20 commands, we list the results of every commands, for +each metric there are two columns, one for the original model trained on wenetspeech L subset, the other +for the finetune model finetuned on in house commands dataset (has 90 hours audio). + +> You can see that the performance of the original model is very poor, I think the reason is the test commands are all collected from real product scenarios which are very different from the scenarios wenetspeech dataset was collected. After finetuning, the performance improves a lot. + +Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours +-- | -- | -- | -- | -- | -- | -- | -- | -- +  | original | finetune | original | finetune | original | finetune | original | finetune +All | 426 / 985 | 40/985 | 56.8% | 95.9% | 7 | 1 | 0.3 | 0.04 +下一个 | 5/50 | 0/50 | 90% | 100% | 3 | 0 | 0.13 | 0 +开灯 | 19/49 | 2/49 | 61.2% | 95.9% | 0 | 0 | 0 | 0 +第一个 | 11/50 | 3/50 | 78% | 94% | 3 | 0 | 0.13 | 0 +声音调到最大 | 39/50 | 7/50 | 22% | 86% | 0 | 0 | 0 | 0 +暂停音乐 | 36/49 | 1/49 | 26.5% | 98% | 0 | 0 | 0 | 0 +暂停播放 | 33/49 | 2/49 | 32.7% | 95.9% | 0 | 0 | 0 | 0 +打开卧室灯 | 33/49 | 1/49 | 32.7% | 98% | 0 | 0 | 0 | 0 +关闭所有灯 | 27/50 | 0/50 | 46% | 100% | 0 | 0 | 0 | 0 +关灯 | 25/48 | 2/48 | 47.9% | 95.8% | 1 | 1 | 0.04 | 0.04 +关闭导航 | 25/48 | 1/48 | 47.9% | 97.9% | 0 | 0 | 0 | 0 +打开蓝牙 | 24/47 | 0/47 | 48.9% | 100% | 0 | 0 | 0 | 0 +下一首歌 | 21/50 | 1/50 | 58% | 98% | 0 | 0 | 0 | 0 +换一首歌 | 19/50 | 5/50 | 62% | 90% | 0 | 0 | 0 | 0 +继续播放 | 19/50 | 2/50 | 62% | 96% | 0 | 0 | 0 | 0 +打开闹钟 | 18/49 | 2/49 | 63.3% | 95.9% | 0 | 0 | 0 | 0 +打开音乐 | 17/49 | 0/49 | 65.3% | 100% | 0 | 0 | 0 | 0 +打开导航 | 17/48 | 0/49 | 64.6% | 100% | 0 | 0 | 0 | 0 +打开电视 | 15/50 | 0/49 | 70% | 100% | 0 | 0 | 0 | 0 +大点声 | 12/50 | 5/50 | 76% | 90% | 0 | 0 | 0 | 0 +小点声 | 11/50 | 6/50 | 78% | 88% | 0 | 0 | 0 | 0 + + +This is the result of large test set, it has more than 100 commands, too many to list the details of each commands, so only an overall result here. We also list the results of two weak up words 小云小云 (only test set)and 你好问问 (both training and test sets). For 你好问问, we have to finetune models, one is finetuned on 你好问问 and our in house commands data, the other finetuned on only 你好问问. Both models perform much better than original model, the one finetuned on only 你好问问 behaves slightly better than the other. 
+ +> 小云小云 test set and 你好问问 training, dev and test sets are available at https://github.com/pkufool/open-commands + +Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours +-- | -- | -- | -- | -- | -- | -- | -- | -- +  | original | finetune | original | finetune | original | finetune | original | finetune +large | 2429/4505 | 477 / 4505 | 46.1% | 89.4% | 50 | 41 | 2.17 | 1.78 +小云小云(clean) | 30/100 | 40/100 | 70% | 60% | 0 | 0 | 0 | 0 +小云小云(noisy) | 118/350 | 154/350 | 66.3% | 56% | 0 | 0 | 0 | 0 +你好问问(finetune with all keywords data) | 2236/10641 | 678/10641 | 79% | 93.6% | 0 | 0 | 0 | 0 +你好问问(finetune with only 你好问问) | 2236/10641 | 249/10641 | 79% | 97.7% | 0 | 0 | 0 | 0 diff --git a/egs/wenetspeech/KWS/prepare.sh b/egs/wenetspeech/KWS/prepare.sh new file mode 100755 index 000000000..dcc65fab4 --- /dev/null +++ b/egs/wenetspeech/KWS/prepare.sh @@ -0,0 +1,90 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare wewetspeech dataset." + mkdir -p data/fbank + if [ ! -e data/fbank/.wewetspeech.done ]; then + pushd ../ASR + ./prepare.sh --stage 0 --stop-stage 17 + ./prepare.sh --stage 22 --stop-stage 22 + popd + pushd data/fbank + ln -svf $(realpath ../ASR/data/fbank/cuts_DEV.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/feats_DEV.lca) . + ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/feats_TEST_NET.lca) . + ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/feats_TEST_MEETING.lca) . + ln -svf $(realpath ../ASR/data/fbank/cuts_L.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/L_split_1000) . + ln -svf $(realpath ../ASR/data/fbank/cuts_M.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/M_split_1000) . + ln -svf $(realpath ../ASR/data/fbank/cuts_S.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/S_split_1000) . + ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/musan_feats) . + popd + pushd data + ln -svf $(realpath ../ASR/data/lang_partial_tone) . + popd + touch data/fbank/.wewetspeech.done + else + log "WenetSpeech dataset already exists, skipping." + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare open commands dataset." + mkdir -p data/fbank + if [ ! -e data/fbank/.cn_speech_commands.done ]; then + pushd data + git clone https://github.com/pkufool/open-commands.git + ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt + ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt + pushd open-commands + ./script/prepare.sh --stage 1 --stop-stage 1 + ./script/prepare.sh --stage 3 --stop-stage 5 + popd + popd + pushd data/fbank + ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_large.jsonl.gz) . 
+ ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_large) . + ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_small.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_small) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_dev.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_dev) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_test.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_test) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_train.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_train) . + ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_clean.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_clean.lca) . + ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_noisy.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_noisy.lca) . + popd + touch data/fbank/.cn_speech_commands.done + else + log "CN speech commands dataset already exists, skipping." + fi +fi diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh new file mode 100755 index 000000000..2bdd6a5f3 --- /dev/null +++ b/egs/wenetspeech/KWS/run.sh @@ -0,0 +1,201 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +export CUDA_VISIBLE_DEVICES="0,1,2,3" +export PYTHONPATH=../../../:$PYTHONPATH + +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Train a model." + if [ ! -e data/fbank/.wenetspeech.done ]; then + log "You need to run the prepare.sh first." + exit -1 + fi + + python ./zipformer/train.py \ + --world-size 4 \ + --exp-dir zipformer/exp \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --num-epochs 18 \ + --lr-epochs 1.5 \ + --use-fp16 1 \ + --start-epoch 1 \ + --training-subset L \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --lang-dir data/lang_partial_tone \ + --max-duration 1000 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Decode the model." + for t in small, large; do + python ./zipformer/decode.py \ + --epoch 18 \ + --avg 2 \ + --exp-dir ./zipformer/exp \ + --tokens ./data/lang_partial_tone/tokens.txt \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --test-set $t \ + --keywords-score 1.5 \ + --keywords-threshold 0.1 \ + --keywords-file ./data/commands_${t}.txt \ + --max-duration 3000 + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Export the model." 
+ + python ./zipformer/export.py \ + --epoch 18 \ + --avg 2 \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_partial_tone/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 + + python ./zipformer/export_onnx_streaming.py \ + --exp-dir zipformer/exp \ + --tokens data/lang_partial_tone/tokens.txt \ + --epoch 18 \ + --avg 2 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 2: Finetune the model" + + # The following configuration of lr schedule should work well + # You may also tune the following parameters to adjust learning rate schedule + base_lr=0.0005 + lr_epochs=100 + lr_batches=100000 + + # We recommend to start from an averaged model + finetune_ckpt=zipformer/exp/pretrained.pt + + ./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --exp-dir zipformer/exp_finetune + --lang-dir ./data/lang_partial_tone \ + --pinyin-type partial_with_tone \ + --use-fp16 1 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 \ + --base-lr $base_lr \ + --lr-epochs $lr_epochs \ + --lr-batches $lr_batches \ + --finetune-ckpt $finetune_ckpt \ + --max-duration 1500 +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 1: Decode the finetuned model." + for t in small, large; do + python ./zipformer/decode.py \ + --epoch 10 \ + --avg 2 \ + --exp-dir ./zipformer/exp_finetune \ + --tokens ./data/lang_partial_tone/tokens.txt \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --test-set $t \ + --keywords-score 0.000001 \ + --keywords-threshold 0.35 \ + --keywords-file ./data/commands_${t}.txt \ + --max-duration 3000 + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 2: Export the finetuned model." 
+ + python ./zipformer/export.py \ + --epoch 10 \ + --avg 2 \ + --exp-dir ./zipformer/exp_finetune \ + --tokens data/lang_partial_tone/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 + + python ./zipformer/export_onnx_streaming.py \ + --exp-dir zipformer/exp_finetune \ + --tokens data/lang_partial_tone/tokens.txt \ + --epoch 10 \ + --avg 2 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 +fi diff --git a/egs/wenetspeech/KWS/shared b/egs/wenetspeech/KWS/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/wenetspeech/KWS/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/asr_datamodule.py b/egs/wenetspeech/KWS/zipformer/asr_datamodule.py new file mode 100644 index 000000000..7de748c8e --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/asr_datamodule.py @@ -0,0 +1,459 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + load_manifest, + load_manifest_lazy, + set_caching_enabled, +) +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class WenetSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. 
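+
+    Example (an illustrative sketch; the cut accessor and the argument
+    values are assumptions, not part of this class)::
+
+        parser = argparse.ArgumentParser()
+        WenetSpeechAsrDataModule.add_arguments(parser)
+        args = parser.parse_args(["--max-duration", "600"])
+        wenetspeech = WenetSpeechAsrDataModule(args)
+        # train_cuts is a CutSet prepared elsewhere (e.g. from lhotse manifests)
+        train_dl = wenetspeech.train_dataloaders(train_cuts)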
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--training-subset", + type=str, + default="L", + help="The training subset for using", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
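+        Returns:
+          A PyTorch DataLoader over a K2SpeechRecognitionDataset; MUSAN mixing,
+          cut concatenation, SpecAugment and on-the-fly feature extraction are
+          applied according to the command-line options above.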
+ """ + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=300000, + drop_last=True, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
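+        # Each dataloader worker receives seed + worker_id (see _SeedWorkers),
+        # so on-the-fly augmentation differs across workers while remaining
+        # reproducible for a fixed base seed.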
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_dl.sampler.load_state_dict(sampler_state_dict) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + + valid_dl = DataLoader( + validate, + batch_size=None, + sampler=valid_sampler, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def test_meeting_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def cn_speech_commands_small_cuts(self) -> CutSet: + logging.info("About to get cn speech commands small cuts") + return load_manifest_lazy( + self.args.manifest_dir / "cn_speech_commands_cuts_small.jsonl.gz" + ) + + @lru_cache() + def cn_speech_commands_large_cuts(self) -> CutSet: + logging.info("About to get cn speech commands large cuts") + return load_manifest_lazy( + self.args.manifest_dir / "cn_speech_commands_cuts_large.jsonl.gz" + ) + + @lru_cache() + def nihaowenwen_dev_cuts(self) -> CutSet: + logging.info("About to get nihaowenwen dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "nihaowenwen_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def nihaowenwen_test_cuts(self) -> CutSet: + logging.info("About to get nihaowenwen test cuts") + return load_manifest_lazy( + self.args.manifest_dir / 
"nihaowenwen_cuts_test.jsonl.gz" + ) + + @lru_cache() + def nihaowenwen_train_cuts(self) -> CutSet: + logging.info("About to get nihaowenwen train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "nihaowenwen_cuts_train.jsonl.gz" + ) + + @lru_cache() + def xiaoyun_clean_cuts(self) -> CutSet: + logging.info("About to get xiaoyun clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "xiaoyun_cuts_clean.jsonl.gz" + ) + + @lru_cache() + def xiaoyun_noisy_cuts(self) -> CutSet: + logging.info("About to get xiaoyun noisy cuts") + return load_manifest_lazy( + self.args.manifest_dir / "xiaoyun_cuts_noisy.jsonl.gz" + ) diff --git a/egs/wenetspeech/KWS/zipformer/beam_search.py b/egs/wenetspeech/KWS/zipformer/beam_search.py new file mode 120000 index 000000000..94033eebf --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/beam_search.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/decode-asr.py b/egs/wenetspeech/KWS/zipformer/decode-asr.py new file mode 100755 index 000000000..6425030eb --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/decode-asr.py @@ -0,0 +1,767 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) fast beam search (LG) +./zipformer/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
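+        # Tail padding: the features are extended by pad_len frames filled
+        # with LOG_EPS (log of a tiny value), giving the causal encoder extra
+        # frames beyond the real audio.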
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + sentence = "".join([lexicon.word_table[i] for i in hyp]) + hyps.append(list(sentence)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + 
key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: 
+ raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + wenetspeech = WenetSpeechAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." 
+ ) + return T > 0 + + dev_cuts = wenetspeech.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = wenetspeech.valid_dataloaders(dev_cuts) + + test_net_cuts = wenetspeech.test_net_cuts() + test_net_cuts = test_net_cuts.filter(remove_short_utt) + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) + + test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) + test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) + + test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] + test_dls = [dev_dl, test_net_dl, test_meeting_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py new file mode 100755 index 000000000..5ed3c6c2c --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -0,0 +1,737 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from beam_search import ( + keywords_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + num_tokens, + setup_logger, + store_transcripts, + str2bool, + text_to_pinyin, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +@dataclass +class KwMetric: + TP: int = 0 # True positive + FN: int = 0 # False negative + FP: int = 0 # False positive + TN: int = 0 # True negative + FN_list: List[str] = field(default_factory=list) + FP_list: List[str] = field(default_factory=list) + TP_list: List[str] = field(default_factory=list) + + def __str__(self) -> str: + return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})" + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/lang_partial_tone/tokens.txt", + help="The path to the token.txt", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--pinyin-type", + type=str, + help="The type of pinyin used as the modeling units.", + ) + + parser.add_argument( + "--keywords-file", + type=str, + help="File contains keywords.", + ) + + parser.add_argument( + "--test-set", + type=str, + default="small", + help="small or large", + ) + + parser.add_argument( + "--keywords-score", + type=float, + default=1.5, + help=""" + The default boosting score (token level) for keywords. it will boost the + paths that match keywords to make them survive beam search. + """, + ) + + parser.add_argument( + "--keywords-threshold", + type=float, + default=0.35, + help="The default threshold (probability) to trigger the keyword.", + ) + + parser.add_argument( + "--num-tailing-blanks", + type=int, + default=1, + help="The number of tailing blanks should have after hitting one keyword.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, + keywords_graph: ContextGraph, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. 
+ batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + ans_dict = keywords_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + keywords_graph=keywords_graph, + beam=params.beam_size, + num_tailing_blanks=8, + ) + + hyps = [] + for ans in ans_dict: + hyp = [] + for hit in ans: + hyp.append( + ( + hit.phrase, + (hit.timestamps[0], hit.timestamps[-1]), + ) + ) + hyps.append(hyp) + + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + keywords_graph: ContextGraph, + keywords: Set[str], + test_only_keywords: bool, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
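+    # len(dl) can raise TypeError when the sampler cannot determine the number
+    # of batches in advance (e.g. lazily-opened cut manifests); "?" is only
+    # used for progress logging below.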
+ + log_interval = 20 + + results = [] + metric = {"all": KwMetric()} + for k in keywords: + metric[k] = KwMetric() + + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps = decode_one_batch( + params=params, + model=model, + keywords_graph=keywords_graph, + batch=batch, + ) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = list(ref_text) + hyp_words = [x[0] for x in hyp_words] + this_batch.append((cut_id, ref_words, list("".join(hyp_words)))) + hyp_set = set(hyp_words) + if len(hyp_words) > 1: + logging.warning( + f"Cut {cut_id} triggers more than one keywords : {hyp_words}," + f"please check the transcript to see if it really has more " + f"than one keywords, if so consider splitting this audio and" + f"keep only one keyword for each audio." + ) + hyp_str = " | ".join( + hyp_words + ) # The triggered keywords for this utterance. + TP = False + FP = False + for x in hyp_set: + assert x in keywords, x # can only trigger keywords + if (test_only_keywords and x == ref_text) or ( + not test_only_keywords and x in ref_text + ): + TP = True + metric[x].TP += 1 + metric[x].TP_list.append(f"({ref_text} -> {x})") + if (test_only_keywords and x != ref_text) or ( + not test_only_keywords and x not in ref_text + ): + FP = True + metric[x].FP += 1 + metric[x].FP_list.append(f"({ref_text} -> {x})") + if TP: + metric["all"].TP += 1 + if FP: + metric["all"].FP += 1 + TN = True # all keywords are true negative then the summery is true negative. + FN = False + for x in keywords: + if x not in ref_text and x not in hyp_set: + metric[x].TN += 1 + continue + + TN = False + if (test_only_keywords and x == ref_text) or ( + not test_only_keywords and x in ref_text + ): + fn = True + for y in hyp_set: + if (test_only_keywords and y == ref_text) or ( + not test_only_keywords and y in ref_text + ): + fn = False + break + if fn: + FN = True + metric[x].FN += 1 + metric[x].FN_list.append(f"({ref_text} -> {hyp_str})") + if TN: + metric["all"].TN += 1 + if FN: + metric["all"].FN += 1 + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results, metric + + +def save_results( + params: AttributeDict, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], + metric: KwMetric, +): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
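+    # In the metric file written below, TP/FP/FN/TN are counted per utterance
+    # for every keyword (plus an "all" summary); precision, recall, FPR and F1
+    # are then derived from those counts in the loop that follows.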
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt" + + with open(metric_filename, "w") as of: + width = 10 + for key, item in sorted( + metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True + ): + acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN) + precision = ( + 0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP) + ) + recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN) + fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN) + s = f"{key}:\n" + s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n" + s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n" + s += f"\tAccuracy: {acc:.3f}\n" + s += f"\tPrecision: {precision:.3f}\n" + s += f"\tRecall(PPR): {recall:.3f}\n" + s += f"\tFPR: {fpr:.3f}\n" + s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n" + if key != "all": + s += f"\tTP list: {' # '.join(item.TP_list)}\n" + s += f"\tFP list: {' # '.join(item.FP_list)}\n" + s += f"\tFN list: {' # '.join(item.FN_list)}\n" + of.write(s + "\n") + if key == "all": + logging.info(s) + of.write(f"\n\n{params.keywords_config}") + + logging.info("Wrote metric stats to {}".format(metric_filename)) + + +@torch.no_grad() +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "kws" + + params.suffix = params.test_set + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + params.suffix += f"-score-{params.keywords_score}" + params.suffix += f"-threshold-{params.keywords_threshold}" + params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" + if params.blank_penalty != 0: + params.suffix += f"-blank-penalty-{params.blank_penalty}" + params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + phrases = [] + token_ids = [] + keywords_scores = [] + keywords_thresholds = [] + keywords_config = [] + with open(params.keywords_file, "r") as f: + for line in f.readlines(): + keywords_config.append(line) + score = 0 + threshold = 0 + keyword = [] + words = line.strip().upper().split() + for word in words: + word = word.strip() + if word[0] == ":": + score = float(word[1:]) + continue + if word[0] == "#": + threshold = float(word[1:]) + continue + keyword.append(word) + keyword = "".join(keyword) + tmp_ids = [] + kws_py = text_to_pinyin(keyword, mode=params.pinyin_type) + for k in kws_py: + if k in token_table: + tmp_ids.append(token_table[k]) + else: + logging.warning(f"Containing OOV tokens, skipping line : {line}") + tmp_ids = [] + break + if tmp_ids: + logging.info(f"Adding keyword : {keyword}") + phrases.append(keyword) + token_ids.append(tmp_ids) + keywords_scores.append(score) + keywords_thresholds.append(threshold) + params.keywords_config = "".join(keywords_config) + + keywords_graph = ContextGraph( + context_score=params.keywords_score, ac_threshold=params.keywords_threshold + ) + keywords_graph.build( + token_ids=token_ids, + phrases=phrases, + scores=keywords_scores, + ac_thresholds=keywords_thresholds, + ) + keywords = set(phrases) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" 
--iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + wenetspeech = WenetSpeechAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + test_net_cuts = wenetspeech.test_net_cuts() + test_net_cuts = test_net_cuts.filter(remove_short_utt) + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) + + cn_commands_small_cuts = wenetspeech.cn_speech_commands_small_cuts() + cn_commands_small_cuts = cn_commands_small_cuts.filter(remove_short_utt) + cn_commands_small_dl = wenetspeech.test_dataloaders(cn_commands_small_cuts) + + cn_commands_large_cuts = wenetspeech.cn_speech_commands_large_cuts() + cn_commands_large_cuts = cn_commands_large_cuts.filter(remove_short_utt) + cn_commands_large_dl = wenetspeech.test_dataloaders(cn_commands_large_cuts) + + nihaowenwen_test_cuts = wenetspeech.nihaowenwen_test_cuts() + nihaowenwen_test_cuts = nihaowenwen_test_cuts.filter(remove_short_utt) + nihaowenwen_test_dl = wenetspeech.test_dataloaders(nihaowenwen_test_cuts) + + xiaoyun_clean_cuts = wenetspeech.xiaoyun_clean_cuts() + xiaoyun_clean_cuts = xiaoyun_clean_cuts.filter(remove_short_utt) + xiaoyun_clean_dl = wenetspeech.test_dataloaders(xiaoyun_clean_cuts) + + xiaoyun_noisy_cuts = wenetspeech.xiaoyun_noisy_cuts() + xiaoyun_noisy_cuts = xiaoyun_noisy_cuts.filter(remove_short_utt) + xiaoyun_noisy_dl = wenetspeech.test_dataloaders(xiaoyun_noisy_cuts) + + test_sets = [] + test_dls = [] + if params.test_set == "large": + test_sets += ["cn_commands_large", "test_net"] + test_dls += [cn_commands_large_dl, test_net_dl] + else: + assert params.test_set == "small", params.test_set + test_sets += [ + "cn_commands_small", + "nihaowenwen", + "xiaoyun_clean", + "xiaoyun_noisy", + "test_net", + ] + test_dls += [ + cn_commands_small_dl, + nihaowenwen_test_dl, + xiaoyun_clean_dl, + xiaoyun_noisy_dl, + test_net_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dls): + results, metric = decode_dataset( + dl=test_dl, + params=params, + model=model, + keywords_graph=keywords_graph, + keywords=keywords, + test_only_keywords="test_net" not in test_set, + ) + + save_results( + params=params, + test_set_name=test_set, + results=results, + metric=metric, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git 
a/egs/wenetspeech/KWS/zipformer/decoder.py b/egs/wenetspeech/KWS/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/encoder_interface.py b/egs/wenetspeech/KWS/zipformer/encoder_interface.py new file mode 120000 index 000000000..2c56d3d18 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py b/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/export.py b/egs/wenetspeech/KWS/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py new file mode 100755 index 000000000..6f34989e2 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model finetuning: +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For non-streaming model finetuning with mux (original dataset): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --use-mux 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model finetuning: +./zipformer/fintune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +# For streaming model finetuning with mux (original dataset): +./zipformer/fintune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + text_to_pinyin, +) + +from train import ( + add_model_arguments, + add_training_arguments, + compute_validation_loss, + display_and_save_batch, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + scan_pessimistic_batches_for_oom, + set_batch_count, +) + + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". 
+ """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + parser.add_argument( + "--continue-finetune", + type=str2bool, + default=False, + help="Continue finetuning or finetune from pre-trained model", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_partial_tone", + help="Path to the pinyin lang directory", + ) + + parser.add_argument( + "--pinyin-type", + type=str, + default="partial_with_tone", + help=""" + The style of the output pinyin, should be: + full_with_tone : zhōng guó + full_no_tone : zhong guo + partial_with_tone : zh ōng g uó + partial_no_tone : zh ong g uo + """, + ) + + parser.add_argument( + "--pinyin-errors", + default="split", + type=str, + help="""How to handle characters that has no pinyin, + see `text_to_pinyin` in icefall/utils.py for details + """, + ) + + add_training_arguments(parser) + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts, sep="/") + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params) + 100000) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
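+            # Concretely: every 100 batches we read the current scale; if it is
+            # below 8.0 (or below 32.0, checked only every 400 batches) we
+            # double it, since a very small scale wastes fp16 precision. If the
+            # scale drops below 0.01 we save a "bad-model" checkpoint and warn;
+            # below 1e-5 we save again and raise, treating it as fatal.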
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + + if params.continue_finetune: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_utt(c: Cut): + if c.duration > 15: + return False + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + return T > 0 + + wenetspeech = WenetSpeechAsrDataModule(args) + + if params.use_mux: + train_cuts = CutSet.mux( + wenetspeech.train_cuts(), + wenetspeech.nihaowenwen_train_cuts(), + weights=[0.9, 0.1], + ) + else: + train_cuts = wenetspeech.nihaowenwen_train_cuts() + + def encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = "/".join( + text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) + ) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_utt) + train_cuts = train_cuts.map(encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load 
the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = wenetspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = wenetspeech.nihaowenwen_dev_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_cuts = valid_cuts.map(encode_text) + valid_dl = wenetspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/wenetspeech/KWS/zipformer/joiner.py b/egs/wenetspeech/KWS/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/model.py b/egs/wenetspeech/KWS/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/optim.py b/egs/wenetspeech/KWS/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/scaling.py b/egs/wenetspeech/KWS/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/scaling_converter.py 
b/egs/wenetspeech/KWS/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/subsampling.py b/egs/wenetspeech/KWS/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py new file mode 100755 index 000000000..5be34ed99 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -0,0 +1,1401 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from 
icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + text_to_pinyin, +) + + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="""Maximum left-contexts for causal training, measured in frames which will + be converted to a number of chunks. If splitting into chunks, + chunk left-context frames will be chosen randomly from this list; else not relevant.""", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def add_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="""The context size in the decoder. 
1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--scan-for-oom-batches", + type=str2bool, + default=False, + help=""" + Whether to scan for oom batches before training, this is helpful for + finding the suitable max_duration, you only need to run it once. + Caution: a little time consuming. + """, + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
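+        For example, with average_period = 200, at batch_idx_train = 1000 the
+        update above becomes: model_avg = 0.2 * model + 0.8 * model_avg.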
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_partial_tone", + help="Path to the pinyin lang directory", + ) + + parser.add_argument( + "--pinyin-type", + type=str, + default="partial_with_tone", + help=""" + The style of the output pinyin, should be: + full_with_tone : zhōng guó + full_no_tone : zhong guo + partial_with_tone : zh ōng g uó + partial_no_tone : zh ong g uo + """, + ) + + parser.add_argument( + "--pinyin-errors", + default="split", + type=str, + help="""How to handle characters that has no pinyin, + see `text_to_pinyin` in icefall/utils.py for details + """, + ) + + add_training_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
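+    # In addition to restoring the model (and, if given, model_avg, optimizer
+    # and scheduler) from `filename`, the values listed in `keys` below are
+    # copied back into `params` so that loss tracking resumes from the
+    # checkpoint.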
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts, sep="/") + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. 
+ tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + wenetspeech = WenetSpeechAsrDataModule(args) + + train_cuts = wenetspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 15.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + return True + + def encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = "/".join( + text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) + ) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_cuts = train_cuts.map(encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = wenetspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = wenetspeech.valid_cuts() + valid_cuts = valid_cuts.map(encode_text) + valid_dl = wenetspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + The compiler to encode texts to ids. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = supervisions["text"] + y = graph_compiler.texts_to_ids(texts) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/wenetspeech/KWS/zipformer/zipformer.py b/egs/wenetspeech/KWS/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py index 5f9571d42..8c2355c87 100644 --- a/icefall/char_graph_compiler.py +++ b/icefall/char_graph_compiler.py @@ -54,7 +54,7 @@ class CharCtcTrainingGraphCompiler(object): self.sos_id = self.token_table[sos_token] self.eos_id = self.token_table[eos_token] - def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + def texts_to_ids(self, texts: List[str], sep: str = "") -> List[List[int]]: """Convert a list of texts to a list-of-list of token IDs. Args: @@ -63,36 +63,21 @@ class CharCtcTrainingGraphCompiler(object): An example containing two strings is given below: ['你好中国', '北京欢迎您'] + sep: + The separator of the items in one sequence, mainly no separator for + Chinese (one character a token), "/" for Chinese characters plus BPE + token and pinyin tokens. Returns: Return a list-of-list of token IDs. """ + assert sep in ("", "/"), sep ids: List[List[int]] = [] whitespace = re.compile(r"([ \t])") for text in texts: - text = re.sub(whitespace, "", text) - sub_ids = [ - self.token_table[txt] if txt in self.token_table else self.oov_id - for txt in text - ] - ids.append(sub_ids) - return ids - - def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]: - """Convert a list of texts (which include chars and bpes) - to a list-of-list of token IDs. - - Args: - texts: - It is a list of strings. - An example containing two strings is given below: - - [['你', '好', '▁C', 'hina'], ['北','京', '▁', 'welcome', '您'] - Returns: - Return a list-of-list of token IDs. 
- """ - ids: List[List[int]] = [] - for text in texts: - text = text.split("/") + if sep == "": + text = re.sub(whitespace, "", text) + else: + text = text.split(sep) sub_ids = [ self.token_table[txt] if txt in self.token_table else self.oov_id for txt in text diff --git a/icefall/context_graph.py b/icefall/context_graph.py index b3d7972a8..138bf4673 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -17,7 +17,7 @@ import os import shutil from collections import deque -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union class ContextState: @@ -31,6 +31,9 @@ class ContextState: node_score: float, output_score: float, is_end: bool, + level: int, + phrase: str = "", + ac_threshold: float = 1.0, ): """Create a ContextState. @@ -51,6 +54,15 @@ class ContextState: the output node for current node. is_end: True if current token is the end of a context. + level: + The distance from current node to root. + phrase: + The context phrase of current state, the value is valid only when + current state is end state (is_end == True). + ac_threshold: + The acoustic threshold (probability) of current context phrase, the + value is valid only when current state is end state (is_end == True). + Note: ac_threshold only used in keywords spotting. """ self.id = id self.token = token @@ -58,7 +70,10 @@ class ContextState: self.node_score = node_score self.output_score = output_score self.is_end = is_end + self.level = level self.next = {} + self.phrase = phrase + self.ac_threshold = ac_threshold self.fail = None self.output = None @@ -75,7 +90,7 @@ class ContextGraph: beam search. """ - def __init__(self, context_score: float): + def __init__(self, context_score: float, ac_threshold: float = 1.0): """Initialize a ContextGraph with the given ``context_score``. A root node will be created (**NOTE:** the token of root is hardcoded to -1). @@ -87,8 +102,12 @@ class ContextGraph: Note: This is just the default score for each token, the users can manually specify the context_score for each word/phrase (i.e. different phrase might have different token score). + ac_threshold: + The acoustic threshold (probability) to trigger the word/phrase, this argument + is used only when applying the graph to keywords spotting system. """ self.context_score = context_score + self.ac_threshold = ac_threshold self.num_nodes = 0 self.root = ContextState( id=self.num_nodes, @@ -97,6 +116,7 @@ class ContextGraph: node_score=0, output_score=0, is_end=False, + level=0, ) self.root.fail = self.root @@ -136,7 +156,13 @@ class ContextGraph: node.output_score += 0 if output is None else output.output_score queue.append(node) - def build(self, token_ids: List[Tuple[List[int], float]]): + def build( + self, + token_ids: List[List[int]], + phrases: Optional[List[str]] = None, + scores: Optional[List[float]] = None, + ac_thresholds: Optional[List[float]] = None, + ): """Build the ContextGraph from a list of token list. It first build a trie from the given token lists, then fill the fail arc for each trie node. @@ -145,52 +171,80 @@ class ContextGraph: Args: token_ids: - The given token lists to build the ContextGraph, it is a list of tuple of - token list and its customized score, the token list contains the token ids + The given token lists to build the ContextGraph, it is a list of + token list, the token list contains the token ids for a word/phrase. The token id could be an id of a char (modeling with single Chinese char) or an id of a BPE - (modeling with BPEs). 
The score is the total score for current token list, + (modeling with BPEs). + phrases: + The given phrases, they are the original text of the token_ids, the + length of `phrases` MUST be equal to the length of `token_ids`. + scores: + The customize boosting score(token level) for each word/phrase, 0 means using the default value (i.e. self.context_score). + It is a list of floats, and the length of `scores` MUST be equal to + the length of `token_ids`. + ac_thresholds: + The customize trigger acoustic threshold (probability) for each phrase, + 0 means using the default value (i.e. self.ac_threshold). It is + used only when this graph applied for the keywords spotting system. + The length of `ac_threshold` MUST be equal to the length of `token_ids`. Note: The phrases would have shared states, the score of the shared states is - the maximum value among all the tokens sharing this state. + the MAXIMUM value among all the tokens sharing this state. """ - for (tokens, score) in token_ids: + num_phrases = len(token_ids) + if phrases is not None: + assert len(phrases) == num_phrases, (len(phrases), num_phrases) + if scores is not None: + assert len(scores) == num_phrases, (len(scores), num_phrases) + if ac_thresholds is not None: + assert len(ac_thresholds) == num_phrases, (len(ac_thresholds), num_phrases) + + for index, tokens in enumerate(token_ids): + phrase = "" if phrases is None else phrases[index] + score = 0.0 if scores is None else scores[index] + ac_threshold = 0.0 if ac_thresholds is None else ac_thresholds[index] node = self.root # If has customized score using the customized token score, otherwise # using the default score - context_score = ( - self.context_score if score == 0.0 else round(score / len(tokens), 2) - ) + context_score = self.context_score if score == 0.0 else score + threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold for i, token in enumerate(tokens): node_next = {} if token not in node.next: self.num_nodes += 1 - node_id = self.num_nodes - token_score = context_score is_end = i == len(tokens) - 1 + node_score = node.node_score + context_score + node.next[token] = ContextState( + id=self.num_nodes, + token=token, + token_score=context_score, + node_score=node_score, + output_score=node_score if is_end else 0, + is_end=is_end, + level=i + 1, + phrase=phrase if is_end else "", + ac_threshold=threshold if is_end else 0.0, + ) else: # node exists, get the score of shared state. token_score = max(context_score, node.next[token].token_score) - node_id = node.next[token].id - node_next = node.next[token].next + node.next[token].token_score = token_score + node_score = node.node_score + token_score + node.next[token].node_score = node_score is_end = i == len(tokens) - 1 or node.next[token].is_end - node_score = node.node_score + token_score - node.next[token] = ContextState( - id=node_id, - token=token, - token_score=token_score, - node_score=node_score, - output_score=node_score if is_end else 0, - is_end=is_end, - ) - node.next[token].next = node_next + node.next[token].output_score = node_score if is_end else 0 + node.next[token].is_end = is_end + if i == len(tokens) - 1: + node.next[token].phrase = phrase + node.next[token].ac_threshold = threshold node = node.next[token] self._fill_fail_output() def forward_one_step( - self, state: ContextState, token: int - ) -> Tuple[float, ContextState]: + self, state: ContextState, token: int, strict_mode: bool = True + ) -> Tuple[float, ContextState, ContextState]: """Search the graph with given state and token. 
Args: @@ -198,9 +252,27 @@ class ContextGraph: The given token containing trie node to start. token: The given token. + strict_mode: + If the `strict_mode` is True, it can match multiple phrases simultaneously, + and will continue to match longer phrase after matching a shorter one. + If the `strict_mode` is False, it can only match one phrase at a time, + when it matches a phrase, then the state will fall back to root state + (i.e. forgetting all the history state and starting a new match). If + the matched state have multiple outputs (node.output is not None), the + longest phrase will be return. + For example, if the phrases are `he`, `she` and `shell`, the query is + `like shell`, when `strict_mode` is True, the query will match `he` and + `she` at token `e` and `shell` at token `l`, while when `strict_mode` + if False, the query can only match `she`(`she` is longer than `he`, so + `she` not `he`) at token `e`. + Caution: When applying this graph for keywords spotting system, the + `strict_mode` MUST be True. Returns: - Return a tuple of score and next state. + Return a tuple of boosting score for current state, next state and matched + state (if any). Note: Only returns the matched state with longest phrase of + current state, even if there are multiple matches phrases. If no phrase + matched, the matched state is None. """ node = None score = 0 @@ -224,7 +296,49 @@ class ContextGraph: # The score of the fail path score = node.node_score - state.node_score assert node is not None - return (score + node.output_score, node) + + # The matched node of current step, will only return the node with + # longest phrase if there are multiple phrases matches this step. + # None if no matched phrase. + matched_node = ( + node if node.is_end else (node.output if node.output is not None else None) + ) + if not strict_mode and node.output_score != 0: + # output_score != 0 means at least on phrase matched + assert matched_node is not None + output_score = ( + node.node_score + if node.is_end + else ( + node.node_score if node.output is None else node.output.node_score + ) + ) + return (score + output_score - node.node_score, self.root, matched_node) + assert (node.output_score != 0 and matched_node is not None) or ( + node.output_score == 0 and matched_node is None + ), ( + node.output_score, + matched_node, + ) + return (score + node.output_score, node, matched_node) + + def is_matched(self, state: ContextState) -> Tuple[bool, ContextState]: + """Whether current state matches any phrase (i.e. current state is the + end state or the output of current state is not None. + + Args: + state: + The given state(trie node). + + Returns: + Return a tuple of status and matched state. 
+ """ + if state.is_end: + return True, state + else: + if state.output is not None: + return True, state.output + return False, None def finalize(self, state: ContextState) -> Tuple[float, ContextState]: """When reaching the end of the decoded sequence, we need to finalize @@ -366,7 +480,7 @@ class ContextGraph: return dot -def _test(queries, score): +def _test(queries, score, strict_mode): contexts_str = [ "S", "HE", @@ -381,11 +495,15 @@ def _test(queries, score): # test default score (1) contexts = [] + scores = [] + phrases = [] for s in contexts_str: - contexts.append(([ord(x) for x in s], score)) + contexts.append([ord(x) for x in s]) + scores.append(round(score / len(s), 2)) + phrases.append(s) context_graph = ContextGraph(context_score=1) - context_graph.build(contexts) + context_graph.build(token_ids=contexts, scores=scores, phrases=phrases) symbol_table = {} for contexts in contexts_str: @@ -402,7 +520,9 @@ def _test(queries, score): total_scores = 0 state = context_graph.root for q in query: - score, state = context_graph.forward_one_step(state, ord(q)) + score, state, phrase = context_graph.forward_one_step( + state, ord(q), strict_mode + ) total_scores += score score, state = context_graph.finalize(state) assert state.token == -1, state.token @@ -427,9 +547,22 @@ if __name__ == "__main__": "DHRHISQ": 4, # "HIS", "S" "THEN": 2, # "HE" } - _test(queries, 0) + _test(queries, 0, True) - # test custom score (5) + queries = { + "HEHERSHE": 7, # "HE", "HE", "S", "HE" + "HERSHE": 5, # "HE", "S", "HE" + "HISHE": 5, # "HIS", "HE" + "SHED": 3, # "S", "HE" + "SHELF": 3, # "S", "HE" + "HELL": 2, # "HE" + "HELLO": 2, # "HE" + "DHRHISQ": 3, # "HIS" + "THEN": 2, # "HE" + } + _test(queries, 0, False) + + # test custom score # S : 5 # HE : 5 (2.5 + 2.5) # SHE : 8.34 (5 + 1.67 + 1.67) @@ -450,4 +583,17 @@ if __name__ == "__main__": "THEN": 5, # "HE" } - _test(queries, 5) + _test(queries, 5, True) + + queries = { + "HEHERSHE": 20, # "HE", "HE", "S", "HE" + "HERSHE": 15, # "HE", "S", "HE" + "HISHE": 10.84, # "HIS", "HE" + "SHED": 10, # "S", "HE" + "SHELF": 10, # "S", "HE" + "HELL": 5, # "HE" + "HELLO": 5, # "HE" + "DHRHISQ": 5.84, # "HIS" + "THEN": 5, # "HE" + } + _test(queries, 5, False) diff --git a/icefall/utils.py b/icefall/utils.py index a9e8a81b9..7d722b1bc 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -28,6 +28,8 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from pathlib import Path +from pypinyin import pinyin, lazy_pinyin +from pypinyin.contrib.tone_convert import to_initials, to_finals_tone, to_finals from shutil import copyfile from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union @@ -327,6 +329,19 @@ def encode_supervisions_otc( return supervision_segments, res, sorted_ids, sorted_verbatim_texts +@dataclass +class KeywordResult: + # timestamps[k] contains the frame number on which tokens[k] + # is decoded + timestamps: List[int] + + # hyps is the keyword, i.e., word IDs or token IDs + hyps: List[int] + + # The triggered phrase + phrase: str + + @dataclass class DecodingResults: # timestamps[i][k] contains the frame number on which tokens[i][k] @@ -1583,6 +1598,87 @@ def load_averaged_model( return model +def text_to_pinyin( + txt: str, mode: str = "full_with_tone", errors: str = "default" +) -> List[str]: + """ + Convert a Chinese text (might contain some latin characters) to pinyin sequence. + + Args: + txt: + The input Chinese text. 
+ mode: + The style of the output pinyin, should be: + full_with_tone : zhōng guó + full_no_tone : zhong guo + partial_with_tone : zh ōng g uó + partial_no_tone : zh ong g uo + errors: + How to handle the characters (latin) that has no pinyin. + default : output the same as input. + split : split into single characters (i.e. alphabets) + + Return: + Return a list of str. + + Examples: + txt: 想吃KFC + output: ['xiǎng', 'chī', 'KFC'] # mode=full_with_tone; errors=default + output: ['xiǎng', 'chī', 'K', 'F', 'C'] # mode=full_with_tone; errors=split + output: ['xiang', 'chi', 'KFC'] # mode=full_no_tone; errors=default + output: ['xiang', 'chi', 'K', 'F', 'C'] # mode=full_no_tone; errors=split + output: ['x', 'iǎng', 'ch', 'ī', 'KFC'] # mode=partial_with_tone; errors=default + output: ['x', 'iang', 'ch', 'i', 'KFC'] # mode=partial_no_tone; errors=default + """ + + assert mode in ( + "full_with_tone", + "full_no_tone", + "partial_no_tone", + "partial_with_tone", + ), mode + + assert errors in ("default", "split"), errors + + txt = txt.strip() + res = [] + if "full" in mode: + if errors == "default": + py = pinyin(txt) if mode == "full_with_tone" else lazy_pinyin(txt) + else: + py = ( + pinyin(txt, errors=lambda x: list(x)) + if mode == "full_with_tone" + else lazy_pinyin(txt, errors=lambda x: list(x)) + ) + res = [x[0] for x in py] if mode == "full_with_tone" else py + else: + if errors == "default": + py = pinyin(txt) if mode == "partial_with_tone" else lazy_pinyin(txt) + else: + py = ( + pinyin(txt, errors=lambda x: list(x)) + if mode == "partial_with_tone" + else lazy_pinyin(txt, errors=lambda x: list(x)) + ) + py = [x[0] for x in py] if mode == "partial_with_tone" else py + for x in py: + initial = to_initials(x, strict=False) + final = ( + to_finals(x, strict=False) + if mode == "partial_no_tone" + else to_finals_tone(x, strict=False) + ) + if initial == "" and final == "": + res.append(x) + else: + if initial != "": + res.append(initial) + if final != "": + res.append(final) + return res + + def tokenize_by_bpe_model( sp: spm.SentencePieceProcessor, txt: str, From 819bb455392be1abe062d3faaa46f93c5d6d6ffd Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 22 Feb 2024 15:50:11 +0800 Subject: [PATCH 106/216] Add pypinyin to requirements (#1515) --- requirements-ci.txt | 1 + requirements.txt | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 6c74f688c..ebea04615 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -18,6 +18,7 @@ git+https://github.com/lhotse-speech/lhotse kaldilm==1.11 kaldialign==0.7.1 num2words +pypinyin==0.50.0 sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 diff --git a/requirements.txt b/requirements.txt index a1a46ae64..dec20c6e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,10 @@ kaldialign num2words kaldi-decoder sentencepiece>=0.1.96 +pypinyin==0.50.0 tensorboard typeguard dill black==22.3.0 onnx==1.15.0 -onnxruntime==1.16.3 \ No newline at end of file +onnxruntime==1.16.3 From 2483b8b4da9abceba24c5ead6a7eb342a00978ee Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 22 Feb 2024 15:53:19 +0800 Subject: [PATCH 107/216] Zipformer recipe for SPGISpeech (#1449) --- egs/spgispeech/ASR/RESULTS.md | 97 +- .../asr_datamodule.py | 23 +- .../ASR/zipformer/asr_datamodule.py | 1 + egs/spgispeech/ASR/zipformer/beam_search.py | 1 + egs/spgispeech/ASR/zipformer/decode.py | 1052 +++++++++++++ 
egs/spgispeech/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + egs/spgispeech/ASR/zipformer/joiner.py | 1 + egs/spgispeech/ASR/zipformer/model.py | 1 + egs/spgispeech/ASR/zipformer/optim.py | 1 + egs/spgispeech/ASR/zipformer/pretrained.py | 382 +++++ egs/spgispeech/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + egs/spgispeech/ASR/zipformer/subsampling.py | 1 + egs/spgispeech/ASR/zipformer/train.py | 1365 +++++++++++++++++ egs/spgispeech/ASR/zipformer/zipformer.py | 1 + 16 files changed, 2912 insertions(+), 18 deletions(-) create mode 120000 egs/spgispeech/ASR/zipformer/asr_datamodule.py create mode 120000 egs/spgispeech/ASR/zipformer/beam_search.py create mode 100755 egs/spgispeech/ASR/zipformer/decode.py create mode 120000 egs/spgispeech/ASR/zipformer/decoder.py create mode 120000 egs/spgispeech/ASR/zipformer/encoder_interface.py create mode 120000 egs/spgispeech/ASR/zipformer/joiner.py create mode 120000 egs/spgispeech/ASR/zipformer/model.py create mode 120000 egs/spgispeech/ASR/zipformer/optim.py create mode 100755 egs/spgispeech/ASR/zipformer/pretrained.py create mode 120000 egs/spgispeech/ASR/zipformer/scaling.py create mode 120000 egs/spgispeech/ASR/zipformer/scaling_converter.py create mode 120000 egs/spgispeech/ASR/zipformer/subsampling.py create mode 100755 egs/spgispeech/ASR/zipformer/train.py create mode 120000 egs/spgispeech/ASR/zipformer/zipformer.py diff --git a/egs/spgispeech/ASR/RESULTS.md b/egs/spgispeech/ASR/RESULTS.md index de9e35c5a..f2da53193 100644 --- a/egs/spgispeech/ASR/RESULTS.md +++ b/egs/spgispeech/ASR/RESULTS.md @@ -1,5 +1,70 @@ ## Results +### SPGISpeech BPE training results (Zipformer Transducer) + +#### 2024-01-05 + +#### Zipformer encoder + embedding decoder + +Transducer: Zipformer encoder + stateless decoder. + +The WERs are: + +| | dev | val | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 2.08 | 2.14 | --epoch 30 --avg 10 | +| modified beam search | 2.05 | 2.09 | --epoch 30 --avg 10 --beam-size 4 | +| fast beam search | 2.07 | 2.17 | --epoch 30 --avg 10 --beam 20 --max-contexts 8 --max-states 64 | + +**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the +transcripts are orthographic or normalized. These WERs correspond to the normalized transcription +scenario. 
+ +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +python zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --num-workers 2 \ + --max-duration 1000 +``` + +The decoding command is: +``` +# greedy search +python ./zipformer/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method greedy_search + +# modified beam search +python ./zipformer/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search + +# fast beam search +python ./zipformer/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +``` + ### SPGISpeech BPE training results (Pruned Transducer) #### 2022-05-11 @@ -43,28 +108,28 @@ The decoding command is: ``` # greedy search ./pruned_transducer_stateless2/decode.py \ - --iter 696000 --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 100 \ - --decoding-method greedy_search + --iter 696000 --avg 10 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 100 \ + --decoding-method greedy_search # modified beam search ./pruned_transducer_stateless2/decode.py \ - --iter 696000 --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 + --iter 696000 --avg 10 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 # fast beam search ./pruned_transducer_stateless2/decode.py \ - --iter 696000 --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --iter 696000 --avg 10 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 ``` Pretrained model is available at diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 7cd6771ce..75c5385a7 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -102,6 +102,20 @@ class SPGISpeechAsrDataModule: help="Determines the maximum duration of a concatenated cut " "relative to the duration of the longest cut in a batch.", ) + group.add_argument( + "--drop-last", + type=str2bool, + default=False, + help="When enabled, the last batch will be dropped", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) group.add_argument( "--gap", type=float, @@ -143,7 +157,7 @@ class SPGISpeechAsrDataModule: group.add_argument( "--num-workers", type=int, - default=8, + default=2, help="The number of training dataloader workers that " "collect the batches.", ) @@ -176,7 +190,7 @@ class SPGISpeechAsrDataModule: The state dict for the training sampler. 
""" logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: @@ -223,11 +237,13 @@ class SPGISpeechAsrDataModule: cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) else: train = K2SpeechRecognitionDataset( cut_transforms=transforms, input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) logging.info("Using DynamicBucketingSampler.") @@ -276,10 +292,12 @@ class SPGISpeechAsrDataModule: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, ) else: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, + return_cuts=self.args.return_cuts, ) valid_sampler = DynamicBucketingSampler( cuts_valid, @@ -303,6 +321,7 @@ class SPGISpeechAsrDataModule: input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( cuts, max_duration=self.args.max_duration, shuffle=False diff --git a/egs/spgispeech/ASR/zipformer/asr_datamodule.py b/egs/spgispeech/ASR/zipformer/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/beam_search.py b/egs/spgispeech/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/decode.py b/egs/spgispeech/ASR/zipformer/decode.py new file mode 100755 index 000000000..90d318919 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/decode.py @@ -0,0 +1,1052 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import SPGISpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
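+ # Editorial comment (the rationale is an assumption, not from the original
+ # authors): for a causal/streaming model the features are right-padded with
+ # pad_len frames of LOG_EPS (log of a very small value, i.e. roughly silence
+ # in the log-mel domain) and feature_lens is extended accordingly,
+ # presumably so the last real frames still get some right context in the
+ # final decoding chunk.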
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + SPGISpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." 
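+ # Editor's note (example values are assumptions): training may pass
+ # comma-separated lists such as --chunk-size 16,32,64,-1, but decoding
+ # requires a single value, e.g. --chunk-size 16 --left-context-frames 128.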
+ assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} 
(excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
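+ # This relies on the datamodule's --return-cuts option (default True in
+ # this recipe, see asr_datamodule.py above), which attaches the originating
+ # cuts to each batch. decode_dataset() then reads them as:
+ #
+ # cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+ #
+ # so that transcripts and error stats can be stored per utterance id.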
+ spgispeech = SPGISpeechAsrDataModule(args) + + dev_cuts = spgispeech.dev_cuts() + val_cuts = spgispeech.val_cuts() + + dev_dl = spgispeech.test_dataloaders(dev_cuts) + val_dl = spgispeech.test_dataloaders(val_cuts) + + test_sets = ["dev", "val"] + test_dl = [dev_dl, val_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/spgispeech/ASR/zipformer/decoder.py b/egs/spgispeech/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/encoder_interface.py b/egs/spgispeech/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/joiner.py b/egs/spgispeech/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/model.py b/egs/spgispeech/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/optim.py b/egs/spgispeech/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/pretrained.py b/egs/spgispeech/ASR/zipformer/pretrained.py new file mode 100755 index 000000000..a562fb9f6 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/pretrained.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for spgispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. 
+ +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +You can also use `./zipformer/exp/epoch-xx.pt`. + +Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/spgispeech/ASR/zipformer/scaling.py b/egs/spgispeech/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/scaling_converter.py b/egs/spgispeech/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/subsampling.py b/egs/spgispeech/ASR/zipformer/subsampling.py new file 
mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py new file mode 100755 index 000000000..1709a2845 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/train.py @@ -0,0 +1,1365 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import SPGISpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have 
used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. 
" + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
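# --------------------------------------------------------------------------
# A quick numeric check of the shape bookkeeping described in the comment
# above.  The second halving (output_downsampling_factor=2) is treated as an
# exact floor division here purely for illustration:
def approx_encoder_frames(num_fbank_frames: int) -> int:
    after_embed = (num_fbank_frames - 7) // 2  # Conv2dSubsampling: T -> (T - 7) // 2
    return after_embed // 2                    # most stacks run at a further 1/2 rate

# e.g. a 10-second utterance at 100 fbank frames/s:
# approx_encoder_frames(1000) == 248
# --------------------------------------------------------------------------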
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
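# --------------------------------------------------------------------------
# The grad-scale bump rule implemented just below, restated as a standalone
# helper for clarity (illustrative only; the real code calls scaler.update()):
def bumped_grad_scale(cur_scale: float, batch_idx: int) -> float:
    if cur_scale < 8.0 or (cur_scale < 32.0 and batch_idx % 400 == 0):
        return cur_scale * 2.0  # gently push small scales back up
    return cur_scale
# --------------------------------------------------------------------------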
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + spgispeech = SPGISpeechAsrDataModule(args) + + train_cuts = spgispeech.train_cuts() + + # Ideally we should filter utterances that are too long or too short, + # but SPGISpeech contains regular length utterances so we don't need to + # do that. 
Here are the statistics of the training data (obtained by + # `train_cuts.describe()`): + + # Cuts count: 5886320 + # Total duration (hours): 15070.1 + # Speech duration (hours): 15070.1 (100.0%) + # *** + # Duration statistics (seconds): + # mean 9.2 + # std 2.8 + # min 4.6 + # 25% 6.9 + # 50% 8.9 + # 75% 11.2 + # 99% 16.0 + # 99.5% 16.3 + # 99.9% 16.6 + # max 16.7 + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = spgispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = spgispeech.dev_cuts() + valid_dl = spgispeech.valid_dataloaders(valid_cuts) + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + SPGISpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/spgispeech/ASR/zipformer/zipformer.py b/egs/spgispeech/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/spgispeech/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 291d06056c241d2772402c3fe88db59d609e7488 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 Feb 2024 14:24:13 +0800 Subject: [PATCH 108/216] Support torch 2.2.1 for cpu docker. (#1516) --- .github/scripts/docker/generate_build_matrix.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index ed01bd740..7bb8ac676 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -43,12 +43,12 @@ def get_torchaudio_version(torch_version): def get_matrix(): - k2_version = "1.24.4.dev20240218" - kaldifeat_version = "1.25.4.dev20240218" - version = "1.4" + k2_version = "1.24.4.dev20240223" + kaldifeat_version = "1.25.4.dev20240223" + version = "20240223" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - torch_version += ["2.2.0"] + torch_version += ["2.2.0", "2.2.1"] matrix = [] for p in python_version: From d89f4ea1492ee7036368e9b46b66acd01a62c46a Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Thu, 29 Feb 2024 10:13:22 +0800 Subject: [PATCH 109/216] Use piper_phonemize as text tokenizer in ljspeech recipe (#1511) * use piper_phonemize as text tokenizer in ljspeech recipe * modify usage of tokenizer in vits/train.py * update docs --- docs/source/recipes/TTS/ljspeech/vits.rst | 6 +- docs/source/recipes/TTS/vctk/vits.rst | 4 +- egs/ljspeech/TTS/local/prepare_token_file.py | 66 +++------------- .../TTS/local/prepare_tokens_ljspeech.py | 11 ++- egs/ljspeech/TTS/prepare.sh | 16 ++-- egs/ljspeech/TTS/vits/export-onnx.py | 3 +- egs/ljspeech/TTS/vits/infer.py | 9 ++- egs/ljspeech/TTS/vits/test_onnx.py | 4 +- egs/ljspeech/TTS/vits/tokenizer.py | 77 ++++++++++++++----- egs/ljspeech/TTS/vits/train.py | 9 ++- requirements-tts.txt | 3 +- 11 files changed, 107 insertions(+), 101 
deletions(-) diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index d08aa0f47..323d0adfc 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -1,11 +1,11 @@ -VITS +VITS-LJSpeech =============== This tutorial shows you how to train an VITS model with the `LJSpeech `_ dataset. .. note:: - + TTS related recipes require packages in ``requirements-tts.txt``. .. note:: @@ -120,4 +120,4 @@ Download pretrained models If you don't want to train from scratch, you can download the pretrained models by visiting the following link: - - ``_ + - ``_ diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst index 34024a5ea..45ae9d9d2 100644 --- a/docs/source/recipes/TTS/vctk/vits.rst +++ b/docs/source/recipes/TTS/vctk/vits.rst @@ -1,11 +1,11 @@ -VITS +VITS-VCTK =============== This tutorial shows you how to train an VITS model with the `VCTK `_ dataset. .. note:: - + TTS related recipes require packages in ``requirements-tts.txt``. .. note:: diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py index df976804a..5b048b600 100755 --- a/egs/ljspeech/TTS/local/prepare_token_file.py +++ b/egs/ljspeech/TTS/local/prepare_token_file.py @@ -17,7 +17,7 @@ """ -This file reads the texts in given manifest and generates the file that maps tokens to IDs. +This file generates the file that maps tokens to IDs. """ import argparse @@ -25,80 +25,38 @@ import logging from pathlib import Path from typing import Dict -from lhotse import load_manifest +from piper_phonemize import get_espeak_map def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--manifest-file", - type=Path, - default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"), - help="Path to the manifest file", - ) - parser.add_argument( "--tokens", type=Path, default=Path("data/tokens.txt"), - help="Path to the tokens", + help="Path to the dict that maps the text tokens to IDs", ) return parser.parse_args() -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. +def get_token2id(filename: Path) -> Dict[str, int]: + """Get a dict that maps token to IDs, and save it to the given filename.""" + all_tokens = get_espeak_map() # token: [token_id] + all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()} + # sort by token_id + all_tokens = sorted(all_tokens.items(), key=lambda x: x[1]) - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_token2id(manifest_file: Path) -> Dict[str, int]: - """Return a dict that maps token to IDs.""" - extra_tokens = [ - "", # 0 for blank - "", # 1 for sos and eos symbols. 
- "", # 2 for OOV - ] - all_tokens = set() - - cut_set = load_manifest(manifest_file) - - for cut in cut_set: - # Each cut only contain one supervision - assert len(cut.supervisions) == 1, len(cut.supervisions) - for t in cut.tokens: - all_tokens.add(t) - - all_tokens = extra_tokens + list(all_tokens) - - token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} - return token2id + for token, token_id in all_tokens: + f.write(f"{token} {token_id}\n") if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - manifest_file = Path(args.manifest_file) out_file = Path(args.tokens) - - token2id = get_token2id(manifest_file) - write_mapping(out_file, token2id) + get_token2id(out_file) diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py index fcd0137a0..08fe7430e 100755 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -23,9 +23,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t import logging from pathlib import Path -import g2p_en import tacotron_cleaner.cleaners from lhotse import CutSet, load_manifest +from piper_phonemize import phonemize_espeak def prepare_tokens_ljspeech(): @@ -35,17 +35,20 @@ def prepare_tokens_ljspeech(): partition = "all" cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - g2p = g2p_en.G2p() new_cuts = [] for cut in cut_set: # Each cut only contains one supervision - assert len(cut.supervisions) == 1, len(cut.supervisions) + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) text = cut.supervisions[0].normalized_text # Text normalization text = tacotron_cleaner.cleaners.custom_english_cleaners(text) # Convert to phonemes - cut.tokens = g2p(text) + tokens_list = phonemize_espeak(text, "en-us") + tokens = [] + for t in tokens_list: + tokens.extend(t) + cut.tokens = tokens new_cuts.append(cut) new_cut_set = CutSet.from_cuts(new_cuts) diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index ed0a07f5e..cbf27bd42 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -30,7 +30,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then cd vits/monotonic_align python setup.py build_ext --inplace cd ../../ - else + else log "monotonic_align lib already built" fi fi @@ -80,6 +80,11 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare phoneme tokens for LJSpeech" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, + # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then ./local/prepare_tokens_ljspeech.py mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ @@ -113,13 +118,12 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Generate token file" - # We assume you have installed g2p_en and espnet_tts_frontend. + # We assume you have installed piper_phonemize and espnet_tts_frontend. 
# If not, please install them with: - # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, + # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/tokens.txt ]; then - ./local/prepare_token_file.py \ - --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \ - --tokens data/tokens.txt + ./local/prepare_token_file.py --tokens data/tokens.txt fi fi diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index f82f9dbe9..c607f0114 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -218,8 +218,7 @@ def main(): params.update(vars(args)) tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size logging.info(params) diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index cf0d20ae2..9e7c71c6d 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -130,14 +130,16 @@ def infer_dataset( batch_size = len(batch["tokens"]) tokens = batch["tokens"] - tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) # tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) audio = batch["audio"] audio_lens = batch["audio_lens"].tolist() @@ -201,8 +203,7 @@ def main(): device = torch.device("cuda", 0) tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size logging.info(f"Device: {device}") diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index fcbc1d663..4f46e8e6c 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -108,7 +108,9 @@ def main(): model = OnnxModel(args.model_filename) text = "I went there to see the land, the people and how their system works, end quote." - tokens = tokenizer.texts_to_token_ids([text]) + tokens = tokenizer.texts_to_token_ids( + [text], intersperse_blank=True, add_sos=True, add_eos=True + ) tokens = torch.tensor(tokens) # (1, T) tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) audio = model(tokens, tokens_lens) # (1, T') diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index b0afc6a04..9a5a9090e 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao) # # See ../../LICENSE for clarification regarding multiple authors # @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
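# --------------------------------------------------------------------------
# A minimal sketch of the id-wrapping this patch adds further down in
# texts_to_token_ids()/tokens_to_token_ids(): a pad/blank id is interspersed
# between tokens and sos/eos ids are optionally attached.  The intersperse()
# behaviour assumed here (padding between and around tokens) is the usual
# VITS-style helper; treat it as an assumption, not a quotation of utils.py.
def wrap_ids(ids, pad_id, sos_id, eos_id):
    out = [pad_id]
    for i in ids:
        out += [i, pad_id]          # token, then a pad/blank after it
    return [sos_id] + out + [eos_id]

# wrap_ids([5, 9], pad_id=0, sos_id=1, eos_id=2) -> [1, 0, 5, 0, 9, 0, 2]
# --------------------------------------------------------------------------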
+import logging from typing import Dict, List -import g2p_en import tacotron_cleaner.cleaners +from piper_phonemize import phonemize_espeak from utils import intersperse @@ -38,21 +39,37 @@ class Tokenizer(object): id = int(info[0]) else: token, id = info[0], int(info[1]) + assert token not in self.token2id, token self.token2id[token] = id - self.blank_id = self.token2id[""] - self.oov_id = self.token2id[""] + # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md + self.pad_id = self.token2id["_"] # padding + self.sos_id = self.token2id["^"] # beginning of an utterance (bos) + self.eos_id = self.token2id["$"] # end of an utterance (eos) + self.space_id = self.token2id[" "] # word separator (whitespace) + self.vocab_size = len(self.token2id) - self.g2p = g2p_en.G2p() - - def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True): + def texts_to_token_ids( + self, + texts: List[str], + intersperse_blank: bool = True, + add_sos: bool = False, + add_eos: bool = False, + lang: str = "en-us", + ) -> List[List[int]]: """ Args: texts: A list of transcripts. intersperse_blank: Whether to intersperse blanks in the token sequence. + add_sos: + Whether to add sos token at the start. + add_eos: + Whether to add eos token at the end. + lang: + Language argument passed to phonemize_espeak(). Returns: Return a list of token id list [utterance][token_id] @@ -63,30 +80,46 @@ class Tokenizer(object): # Text normalization text = tacotron_cleaner.cleaners.custom_english_cleaners(text) # Convert to phonemes - tokens = self.g2p(text) + tokens_list = phonemize_espeak(text, lang) + tokens = [] + for t in tokens_list: + tokens.extend(t) + token_ids = [] for t in tokens: - if t in self.token2id: - token_ids.append(self.token2id[t]) - else: - token_ids.append(self.oov_id) + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + token_ids.append(self.token2id[t]) if intersperse_blank: - token_ids = intersperse(token_ids, self.blank_id) + token_ids = intersperse(token_ids, self.pad_id) + if add_sos: + token_ids = [self.sos_id] + token_ids + if add_eos: + token_ids = token_ids + [self.eos_id] token_ids_list.append(token_ids) return token_ids_list def tokens_to_token_ids( - self, tokens_list: List[str], intersperse_blank: bool = True - ): + self, + tokens_list: List[str], + intersperse_blank: bool = True, + add_sos: bool = False, + add_eos: bool = False, + ) -> List[List[int]]: """ Args: tokens_list: A list of token list, each corresponding to one utterance. intersperse_blank: Whether to intersperse blanks in the token sequence. + add_sos: + Whether to add sos token at the start. + add_eos: + Whether to add eos token at the end. 
Returns: Return a list of token id list [utterance][token_id] @@ -96,13 +129,17 @@ class Tokenizer(object): for tokens in tokens_list: token_ids = [] for t in tokens: - if t in self.token2id: - token_ids.append(self.token2id[t]) - else: - token_ids.append(self.oov_id) + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + token_ids.append(self.token2id[t]) if intersperse_blank: - token_ids = intersperse(token_ids, self.blank_id) + token_ids = intersperse(token_ids, self.pad_id) + if add_sos: + token_ids = [self.sos_id] + token_ids + if add_eos: + token_ids = token_ids + [self.eos_id] token_ids_list.append(token_ids) diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 71c4224fa..6589b75ff 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): features_lens = batch["features_lens"].to(device) tokens = batch["tokens"] - tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) return audio, audio_lens, features, features_lens, tokens, tokens_lens @@ -742,8 +744,7 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size logging.info(params) diff --git a/requirements-tts.txt b/requirements-tts.txt index c30e23d54..eae50ba7b 100644 --- a/requirements-tts.txt +++ b/requirements-tts.txt @@ -3,4 +3,5 @@ matplotlib==3.8.2 cython==3.0.6 numba==0.58.1 g2p_en==2.1.0 -espnet_tts_frontend==0.0.3 \ No newline at end of file +espnet_tts_frontend==0.0.3 +# piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 From 7e2b561bbf201784c02bb197b25ff32289ab60e3 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 29 Feb 2024 10:57:38 +0800 Subject: [PATCH 110/216] Add recipe for fine-tuning Zipformer with adapter (#1512) --- egs/librispeech/ASR/README.md | 1 + .../ASR/zipformer_adapter/asr_datamodule.py | 1 + .../ASR/zipformer_adapter/beam_search.py | 1 + .../ASR/zipformer_adapter/decode.py | 1062 +++++++ .../zipformer_adapter/decode_gigaspeech.py | 1115 ++++++++ .../ASR/zipformer_adapter/decoder.py | 1 + .../zipformer_adapter/encoder_interface.py | 1 + .../ASR/zipformer_adapter/export-onnx.py | 621 ++++ .../ASR/zipformer_adapter/joiner.py | 1 + .../ASR/zipformer_adapter/model.py | 1 + .../ASR/zipformer_adapter/onnx_decode.py | 385 +++ .../ASR/zipformer_adapter/onnx_pretrained.py | 1 + .../ASR/zipformer_adapter/optim.py | 1 + .../ASR/zipformer_adapter/scaling.py | 1 + .../zipformer_adapter/scaling_converter.py | 1 + .../ASR/zipformer_adapter/subsampling.py | 1 + .../ASR/zipformer_adapter/train.py | 1541 ++++++++++ .../ASR/zipformer_adapter/zipformer.py | 2515 +++++++++++++++++ 18 files changed, 7251 insertions(+) create mode 
120000 egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/beam_search.py create mode 100755 egs/librispeech/ASR/zipformer_adapter/decode.py create mode 100755 egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/decoder.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/encoder_interface.py create mode 100755 egs/librispeech/ASR/zipformer_adapter/export-onnx.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/joiner.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/model.py create mode 100755 egs/librispeech/ASR/zipformer_adapter/onnx_decode.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/optim.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/scaling.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/scaling_converter.py create mode 120000 egs/librispeech/ASR/zipformer_adapter/subsampling.py create mode 100755 egs/librispeech/ASR/zipformer_adapter/train.py create mode 100644 egs/librispeech/ASR/zipformer_adapter/zipformer.py diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 1c8930818..5c5a76917 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -35,6 +35,7 @@ The following table lists the differences among them. | `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) | | `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty | | `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe | +| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters | The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). diff --git a/egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py b/egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/beam_search.py b/egs/librispeech/ASR/zipformer_adapter/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py new file mode 100755 index 000000000..bfa4cc230 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/decode.py @@ -0,0 +1,1062 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +- To activate the adapter (test on the target domain) +set --use-adapter True + +- To deactivate the adapter (test on the original domain) +set --use-adapter False + +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the 
checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. 
+ Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + 
hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. 
+ """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in 
decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + import pdb; pdb.set_trace() + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device), strict=False) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device), strict=False) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + 
"Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py new file mode 100755 index 000000000..903014f4a --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py @@ -0,0 +1,1115 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
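# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): decode_gigaspeech.py, which
# starts below, normalizes both reference and hypothesis text in the
# GigaSpeech style before scoring -- uppercase everything, split hyphenated
# words, and drop conversational fillers and annotation tags so they never
# count as errors. The tag set here is a small subset, following GigaSpeech
# conventions, for demonstration only.
NON_SCORING = {"UH", "UM", "HM", "<UNK>", "<COMMA>", "<PERIOD>", "<SIL>", "<MUSIC>"}


def normalize_for_scoring(text: str) -> str:
    text = text.upper()            # scoring is case-insensitive
    text = text.replace("-", " ")  # "E-COMMERCE" -> "E COMMERCE"
    kept = [w for w in text.split() if w not in NON_SCORING]
    return " ".join(kept)


print(normalize_for_scoring("um I bought it on e-commerce <COMMA>"))
# -> "I BOUGHT IT ON E COMMERCE"
# ---------------------------------------------------------------------------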
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, add_finetune_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["", ""] +gigaspeech_punctuations = [ + "", + "", + "", + "", +] +gigaspeech_garbage_utterance_tags = ["", "", "", ""] +non_scoring_words = ( + conversational_filler + + unk_tags + + gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. 
remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. 
The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and 
params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: 
torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / (params.decoding_method + "_giga") + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
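+        # For example, decoding a causal model with --chunk-size 32 and
+        # --left-context-frames 128 makes the two lines below append
+        # "-chunk-32-left-context-128" to params.suffix.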
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
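+    # With return_cuts=True, each decoded utterance is kept as a
+    # (cut_id, ref_words, hyp_words) tuple, which is the format expected by
+    # store_transcripts() and write_error_stats() in save_results().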
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts() + + dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts) + test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_adapter/decoder.py b/egs/librispeech/ASR/zipformer_adapter/decoder.py new file mode 120000 index 000000000..cab465d2b --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/decoder.py @@ -0,0 +1 @@ +../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/encoder_interface.py b/egs/librispeech/ASR/zipformer_adapter/encoder_interface.py new file mode 120000 index 000000000..aa5d0217a --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/export-onnx.py b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py new file mode 100755 index 000000000..a1fc41664 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py @@ -0,0 +1,621 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. 
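+
+In addition to the float32 models above, this script also writes int8
+dynamically-quantized copies (encoder-epoch-99-avg-1.int8.onnx,
+decoder-epoch-99-avg-1.int8.onnx and joiner-epoch-99-avg-1.int8.onnx);
+see the quantize_dynamic() calls at the end of this file.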
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, add_finetune_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. 
+
+ The exported model has one input:
+
+ - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
+
+ and has one output:
+
+ - decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
+
+ Args:
+ decoder_model:
+ The decoder model to be exported.
+ decoder_filename:
+ Filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ context_size = decoder_model.decoder.context_size
+ vocab_size = decoder_model.decoder.vocab_size
+
+ y = torch.zeros(10, context_size, dtype=torch.int64)
+ decoder_model = torch.jit.script(decoder_model)
+ torch.onnx.export(
+ decoder_model,
+ y,
+ decoder_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["y"],
+ output_names=["decoder_out"],
+ dynamic_axes={
+ "y": {0: "N"},
+ "decoder_out": {0: "N"},
+ },
+ )
+
+ meta_data = {
+ "context_size": str(context_size),
+ "vocab_size": str(vocab_size),
+ }
+ add_meta_data(filename=decoder_filename, meta_data=meta_data)
+
+
+def export_joiner_model_onnx(
+ joiner_model: nn.Module,
+ joiner_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the joiner model to ONNX format.
+ The exported joiner model has two inputs:
+
+ - encoder_out: a tensor of shape (N, joiner_dim)
+ - decoder_out: a tensor of shape (N, joiner_dim)
+
+ and produces one output:
+
+ - logit: a tensor of shape (N, vocab_size)
+ """
+ joiner_dim = joiner_model.output_linear.weight.shape[1]
+ logging.info(f"joiner dim: {joiner_dim}")
+
+ projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+ projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+
+ torch.onnx.export(
+ joiner_model,
+ (projected_encoder_out, projected_decoder_out),
+ joiner_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=[
+ "encoder_out",
+ "decoder_out",
+ ],
+ output_names=["logit"],
+ dynamic_axes={
+ "encoder_out": {0: "N"},
+ "decoder_out": {0: "N"},
+ "logit": {0: "N"},
+ },
+ )
+ meta_data = {
+ "joiner_dim": str(joiner_dim),
+ }
+ add_meta_data(filename=joiner_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ 
model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + 
model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer_adapter/joiner.py b/egs/librispeech/ASR/zipformer_adapter/joiner.py new file mode 120000 index 000000000..444cb5f15 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/joiner.py @@ -0,0 +1 @@ +../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/model.py b/egs/librispeech/ASR/zipformer_adapter/model.py new file mode 120000 index 000000000..0c6fe6112 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/model.py @@ -0,0 +1 @@ +../zipformer/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py b/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py new file mode 100755 index 000000000..000cea163 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +2. 
Run this file
+
+./zipformer/onnx_decode.py \
+ --exp-dir $repo/exp \
+ --max-duration 600 \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+"""
+
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+
+from onnx_pretrained import greedy_search, OnnxModel
+
+from icefall.utils import setup_logger, store_transcripts, write_error_stats
+from k2 import SymbolTable
+
+conversational_filler = [
+ "UH",
+ "UHH",
+ "UM",
+ "EH",
+ "MM",
+ "HM",
+ "AH",
+ "HUH",
+ "HA",
+ "ER",
+ "OOF",
+ "HEE",
+ "ACH",
+ "EEE",
+ "EW",
+]
+unk_tags = ["<UNK>", "<unk>"]
+gigaspeech_punctuations = [
+ "<COMMA>",
+ "<PERIOD>",
+ "<QUESTIONMARK>",
+ "<EXCLAMATIONPOINT>",
+]
+gigaspeech_garbage_utterance_tags = ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"]
+non_scoring_words = (
+ conversational_filler
+ + unk_tags
+ + gigaspeech_punctuations
+ + gigaspeech_garbage_utterance_tags
+)
+
+
+def asr_text_post_processing(text: str) -> str:
+ # 1. convert to uppercase
+ text = text.upper()
+
+ # 2. remove hyphen
+ # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
+ text = text.replace("-", " ")
+
+ # 3. remove non-scoring words from evaluation
+ remaining_words = []
+ for word in text.split():
+ if word in non_scoring_words:
+ continue
+ remaining_words.append(word)
+
+ return " ".join(remaining_words)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--encoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the encoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--decoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the decoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--joiner-model-filename",
+ type=str,
+ required=True,
+ help="Path to the joiner onnx model. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ help="""Path to tokens.txt.""",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="Valid values are greedy_search and modified_beam_search",
+ )
+
+ return parser
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+def decode_one_batch(
+ model: OnnxModel, token_table: SymbolTable, batch: dict
+) -> List[List[str]]:
+ """Decode one batch and return the result.
+ Currently only greedy_search is supported.
+
+ Args:
+ model:
+ The neural model.
+ token_table:
+ The token table.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+
+ Returns:
+ Return the decoded results for each utterance.
+ """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + hyps = [token_ids_to_words(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + The token table. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." 
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = SymbolTable.from_file(args.tokens) + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts() + + dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts) + test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py b/egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py new file mode 120000 index 000000000..a085def83 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py @@ -0,0 +1 @@ +../zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/optim.py b/egs/librispeech/ASR/zipformer_adapter/optim.py new file mode 120000 index 000000000..207eecfcd --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/optim.py @@ -0,0 +1 @@ +../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/scaling.py b/egs/librispeech/ASR/zipformer_adapter/scaling.py new file mode 120000 index 000000000..58e4b0a0f --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/scaling.py @@ -0,0 +1 @@ +../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/scaling_converter.py b/egs/librispeech/ASR/zipformer_adapter/scaling_converter.py new file mode 120000 index 000000000..bc7c7b5e3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/scaling_converter.py @@ -0,0 +1 @@ +../zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/subsampling.py b/egs/librispeech/ASR/zipformer_adapter/subsampling.py new file mode 120000 index 000000000..d178adc2e --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/subsampling.py @@ -0,0 +1 @@ +../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py new file mode 100755 index 000000000..7f81ddd96 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -0,0 +1,1541 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Finetune non-streaming model using adapters: + +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --use-mux 0 \ + --use-adapters 1 \ + --adapter-dim 16 \ + --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
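+    # Illustrative example: with batch_idx_train=10000, max_duration=1000 (s),
+    # world_size=4 and ref_duration=600 (s), this returns
+    # 10000 * 4000 / 600 + 1000000 ≈ 1066667.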
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + 1000000 + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. This is useful + if you want to maintain the performance on the original domain + """, + ) + + parser.add_argument( + "--use-adapters", + type=str2bool, + default=True, + help="If use adapter to finetune the model" + ) + + parser.add_argument( + "--adapter-dim", + type=int, + default=16, + help="The bottleneck dimension of the adapter" + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates 
to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. 
+ """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. 
+ + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
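+    # For example, an utterance with T=1000 feature frames (10 s at a 10 ms
+    # frame shift) leaves encoder_embed with (1000 - 7) // 2 = 496 frames, and
+    # the final x2 downsampling yields roughly 248 frames at the encoder output.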
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + use_adapters=params.use_adapters, + adapter_dim=params.adapter_dim, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + # set modules except adapters to eval mode + for name, m in model.named_modules(): + if "adapter" in name: + m.training = True + else: + m.training = False + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Computing validation loss on {valid_set}") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + logging.info( + f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train + ) + model.train() + # set modules except adapters to eval mode + for name, m in model.named_modules(): + if "adapter" in name: + m.training = True + else: + m.training = False + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = 
params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoints.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if not params.use_transducer:
+        params.ctc_loss_scale = 1.0
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    # load model parameters for model fine-tuning
+    if params.do_finetune:
+        assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
+        modules = params.init_modules.split(",") if params.init_modules else None
+        checkpoints = load_model_params(
+            ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False
+        )
+        # Need to update model_avg if we use initialisation
+        if rank == 0:
+            # model_avg is only used with rank 0
+            model_avg = copy.deepcopy(model).to(torch.float64)
+    else:
+        assert params.start_epoch > 1, params.start_epoch
+        checkpoints = load_checkpoint_if_available(
+            params=params, model=model, model_avg=model_avg
+        )
+
+    # keep the original model untouched, only update the adapters
+    num_trainable = 0
+    for name, p in model.named_parameters():
+        if "adapter" in name:
+            p.requires_grad = True
+            num_trainable += p.numel()
+        else:
+            p.requires_grad = False
+
+    logging.info(
+        "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
+            num_trainable, num_trainable / num_param * 100
+        )
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,  # should have no effect
+        clipping_scale=2.0,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            512
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if 
params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts() + if params.use_mux: + librispeech_cuts = librispeech.train_all_shuf_cuts() + train_cuts = CutSet.mux( + gigaspeech_cuts, # num cuts = 688182 + librispeech_cuts, # num cuts = 843723 + weights=[688182, 843723], + stop_early=True, + ) + else: + train_cuts = gigaspeech_cuts + logging.info(train_cuts) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + + valid_sets = ["librispeech", "gigaspeech"] + valid_dls = [ + librispeech.valid_dataloaders(valid_cuts), + librispeech.valid_dataloaders(gigaspeech_dev_cuts), + ] + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + 
optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py new file mode 100644 index 000000000..e4e26cd84 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -0,0 +1,2515 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, + Balancer, + BiasNorm, + ChunkCausalDepthwiseConv1d, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, + SwooshL, + SwooshR, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. 
+ use_adapters: insert adapters in the zipformer encoder + adapter_dim: the dimension of the adapters + """ + + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + use_adapters: bool = False, + adapter_dim: int = 16, + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + self.use_adapters = use_adapters + + for u, d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + use_adapters=use_adapters, + adapter_dim=adapter_dim, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. 
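+            # For example (illustrative numbers only): with warmup_batches=4000
+            # and two stacks, stack 0 warms up over batches ~1333..2667 and
+            # stack 1 over ~2667..4000, following the
+            # warmup_batches * (i + 1) / (num_encoders + 1) schedule below.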
+ encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. 
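+        Both are -1 if the model is not causal; otherwise they are sampled
+        from self.chunk_size and self.left_context_frames (a sampled
+        chunk_size of -1 again means no chunking).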
+ """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + ) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). 
+ chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. 
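+        # i.e. lengths = (x_lens + 1) // 2 == ceil(x_lens / 2).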
+ assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. 
+ use_adapters: insert adapters in each layer + adapter_dim: the bottleneck dimension of the adapter + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + use_adapters: bool=False, + adapter_dim: int=16, + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. 
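+        # Note: these skip rates are ScheduledFloat values (see the defaults in
+        # the constructor arguments above), so they decay towards 0.0 as the
+        # batch count grows and the modules are skipped less and less often.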
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. 
+ self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + self.use_adapters = use_adapters + if use_adapters: + self.mid_adapter = AdapterModule( + embed_dim=embed_dim, + bottleneck_dim=adapter_dim, + ) + + # placed after the 1st self-attn module + self.post_sa_adapter = AdapterModule( + embed_dim=embed_dim, + bottleneck_dim=adapter_dim, + ) + + # placed after the 2nd convolution module + self.post_conv_adapter = AdapterModule( + embed_dim=embed_dim, + bottleneck_dim=adapter_dim, + ) + + # at the end of each layer + self.adapter = AdapterModule( + embed_dim=embed_dim, + bottleneck_dim=adapter_dim, + ) + else: + self.mid_adapter = None + self.post_sa_adapter = None + self.post_conv_adapter = None + self.adapter = None + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. 
+ + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif not self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if self.use_adapters and self.post_sa_adapter is not None: + src = self.post_sa_adapter(src) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) + + # bypass in the middle of the layer. 
+ src = self.bypass_mid(src_orig, src) + + if self.use_adapters and self.mid_adapter is not None: + src = self.mid_adapter(src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if self.use_adapters and self.post_conv_adapter is not None: + src = self.post_conv_adapter(src) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + if self.use_adapters and self.adapter is not None: + src = self.adapter(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. 
+ + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.bypass(src_orig, src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. 
+ + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + return output + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + output, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + output, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + return output, new_states + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. 
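`Zipformer2Encoder.streaming_forward` above keeps every per-layer cache in one flat list, six tensors per layer, and slices it with `states[i * 6 : (i + 1) * 6]`. A minimal sketch of that bookkeeping (placeholder tensors and sizes, not part of the patch):

```python
import torch

# Six caches per layer, in the order documented above.
NAMES = ("cached_key", "cached_nonlin_attn", "cached_val1",
         "cached_val2", "cached_conv1", "cached_conv2")

num_layers = 3
# Placeholder states; the real shapes depend on the layer configuration.
states = [torch.zeros(1) for _ in range(num_layers * 6)]

for i in range(num_layers):
    layer_states = states[i * 6 : (i + 1) * 6]
    per_layer = dict(zip(NAMES, layer_states))
    assert len(per_layer) == 6
```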
+ """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, downsample, dropout) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. 
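The interpolation in `BypassModule.forward` above is worth seeing with numbers: a scale of 0 passes the original input straight through (full bypass), while a scale of 1 keeps the module output unchanged. A small standalone sketch, not part of the patch:

```python
import torch

def bypass(src_orig, src, scale):
    # Same formula as BypassModule.forward: the scale sits on the non-residual term.
    return src_orig + (src - src_orig) * scale

src_orig = torch.zeros(4, 2, 8)  # (seq_len, batch, channels), illustrative sizes
src = torch.randn(4, 2, 8)       # pretend module output

assert torch.allclose(bypass(src_orig, src, 0.0), src_orig)  # fully bypassed
assert torch.allclose(bypass(src_orig, src, 1.0), src)       # module output kept
half = bypass(src_orig, src, 0.5)                            # halfway blend
assert torch.allclose(half, 0.5 * src)                       # since src_orig is zero here
```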
+ """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Downsample, go through encoder, upsample, in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); + True means masked position. May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + src_orig = src + src = self.downsample(src) + + src, new_states = self.encoder.streaming_forward( + src, + states=states, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src), new_states + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. 
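`SimpleDownsample` above pads the sequence up to a multiple of the downsampling factor by repeating the last frame, then takes a learned softmax-weighted average over each group of `ds` frames. A standalone sketch of the same arithmetic with assumed sizes (not part of the patch):

```python
import torch

seq_len, batch_size, channels, ds = 7, 2, 4, 3   # illustrative sizes
src = torch.randn(seq_len, batch_size, channels)
bias = torch.zeros(ds)                           # learned parameter in the real module

d_seq_len = (seq_len + ds - 1) // ds             # 3 output frames
pad = d_seq_len * ds - seq_len                   # 2 padding frames
src = torch.cat([src, src[-1:].expand(pad, batch_size, channels)], dim=0)

src = src.reshape(d_seq_len, ds, batch_size, channels)
weights = bias.softmax(dim=0).unsqueeze(-1).unsqueeze(-1)  # (ds, 1, 1)
out = (src * weights).sum(dim=1)                 # weighted average per group
assert out.shape == (d_seq_len, batch_size, channels)
```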
+ """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. 
+ # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. 
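The compression used by `CompactRelPositionalEncoding.extend_pe` above is easy to check numerically: offsets near zero keep roughly unit spacing, while large offsets get squashed together before the `atan()`. A standalone sketch with assumed values (not part of the patch):

```python
import math
import torch

embed_dim, length_factor = 192, 1.0           # illustrative values
compression_length = embed_dim ** 0.5
length_scale = length_factor * embed_dim / (2.0 * math.pi)

x = torch.tensor([0.0, 1.0, 10.0, 100.0, 1000.0])  # relative offsets
x_compressed = (
    compression_length
    * x.sign()
    * ((x.abs() + compression_length).log() - math.log(compression_length))
)
x_atan = (x_compressed / length_scale).atan()  # bounded in (-pi/2, pi/2)

# After atan(), the gap between offsets 100 and 1000 is smaller than the gap
# between 10 and 100, which is the intended loss of resolution at long range.
print(x_compressed.tolist())
print(x_atan.tolist())
```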
+ + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. 
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. 
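The `as_strided`/`gather` step above converts scores indexed by relative offset into scores indexed by absolute key position. The gather-based variant used for tracing is easy to verify by hand; a small standalone sketch with an assumed sequence length, showing that query `t` and key `s` end up reading the embedding for offset `s - t`:

```python
import torch

seq_len = 4                                               # time1 == seq_len here
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)   # [3, 2, 1, 0]
cols = torch.arange(seq_len)                              # [0, 1, 2, 3]
indexes = rows.unsqueeze(-1) + cols                       # (seq_len, seq_len)

# The position axis has length 2*seq_len - 1 and index i encodes offset i - (seq_len - 1),
# so entry (t, s) of `indexes` picks the embedding for relative offset s - t.
offsets = indexes - (seq_len - 1)
print(offsets)
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1],
#         [-3, -2, -1,  0]])
```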
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + left_context_len + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] 
+ pos_scores = torch.matmul(p, pos_emb) + + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. 
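`SelfAttention` applies externally computed attention weights to a projected value tensor; a shape-level sketch of that contraction, mirroring the forward body that follows, with assumed sizes (not part of the patch):

```python
import torch

seq_len, batch_size, num_heads, value_head_dim = 6, 2, 4, 12  # illustrative sizes
attn_weights = torch.randn(num_heads, batch_size, seq_len, seq_len).softmax(dim=-1)
values = torch.randn(seq_len, batch_size, num_heads * value_head_dim)

v = values.reshape(seq_len, batch_size, num_heads, value_head_dim).permute(2, 1, 0, 3)
out = torch.matmul(attn_weights, v)  # (num_heads, batch, seq_len, value_head_dim)
out = out.permute(2, 1, 0, 3).reshape(seq_len, batch_size, num_heads * value_head_dim)
assert out.shape == values.shape     # same shape as the projected input
```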
+ """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
+ x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. 
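The gating at the start of `NonlinAttention.forward` above is a three-way split of one projection: one slice becomes a tanh gate, one is the tensor that the attention weights mix (shown just below), and one gates the result afterwards. A standalone shape sketch with assumed sizes (not part of the patch):

```python
import torch
import torch.nn as nn

seq_len, batch_size, channels, hidden = 6, 2, 64, 48  # illustrative sizes
in_proj = nn.Linear(channels, hidden * 3)

x_in = torch.randn(seq_len, batch_size, channels)
s, x, y = in_proj(x_in).chunk(3, dim=2)   # each (seq_len, batch, hidden)

x = x * torch.tanh(s)   # tanh gate applied before the attention mixing
# ... attention-weighted mixing of x happens here in the real module ...
out_pre = x * y         # second gate applied after the mixing
assert out_pre.shape == (seq_len, batch_size, hidden)
```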
+ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. 
(for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). 
+ cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + + c = Zipformer2( + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +class AdapterModule(nn.Module): + def __init__( + self, + embed_dim: int=384, + bottleneck_dim: int=16, + ): + # The simplest adapter + super(AdapterModule, self).__init__() + self.embed_dim = embed_dim + self.bottleneck_dim = bottleneck_dim + self.activation = SwooshL() + + self.in_proj = nn.Linear(embed_dim, bottleneck_dim) + self.out_proj = nn.Linear(bottleneck_dim, embed_dim) + + def forward(self, x): + x_orig = x + x = self.activation(self.in_proj(x)) + x = self.out_proj(x) + return x_orig + x + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) From 2f102eb989a16fb7f221fdac214f607824f68d9e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 29 Feb 2024 11:41:18 +0800 Subject: [PATCH 111/216] Add CUDA docker image for torch 2.2.1 (#1521) --- .github/workflows/build-docker-image.yml | 2 +- .github/workflows/run-docker-image.yml | 2 +- docker/torch1.12.1-cuda11.3.dockerfile | 4 +- docker/torch1.13.0-cuda11.6.dockerfile | 4 +- docker/torch1.9.0-cuda10.2.dockerfile | 4 +- docker/torch2.0.0-cuda11.7.dockerfile | 4 +- docker/torch2.1.0-cuda11.8.dockerfile | 4 +- docker/torch2.1.0-cuda12.1.dockerfile | 4 +- docker/torch2.2.0-cuda11.8.dockerfile | 4 +- docker/torch2.2.0-cuda12.1.dockerfile | 4 +- docker/torch2.2.1-cuda11.8.dockerfile | 70 ++++++++++++++++++++++++ docker/torch2.2.1-cuda12.1.dockerfile | 70 ++++++++++++++++++++++++ docs/source/docker/intro.rst | 2 + 13 files changed, 160 insertions(+), 18 deletions(-) create mode 100644 docker/torch2.2.1-cuda11.8.dockerfile create mode 100644 docker/torch2.2.1-cuda12.1.dockerfile diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index 
d5081f7d8..f5796d114 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index 65ba2cd64..eab31cccc 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v2 diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index cb885e59e..33ecbf4d1 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.7 -ARG K2_VERSION="1.24.4.dev20240211+cuda11.3.torch1.12.1" -ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.3.torch1.12.1" +ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1" ARG TORCHAUDIO_VERSION="0.12.1+cu113" LABEL authors="Fangjun Kuang " diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index e238d87aa..b4d62b0bc 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.9 -ARG K2_VERSION="1.24.4.dev20240211+cuda11.6.torch1.13.0" -ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.6.torch1.13.0" +ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0" ARG TORCHAUDIO_VERSION="0.13.0+cu116" LABEL authors="Fangjun Kuang " diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile index 26d45cafc..4d2d3058a 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.7 -ARG K2_VERSION="1.24.4.dev20240211+cuda10.2.torch1.9.0" -ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda10.2.torch1.9.0" +ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0" ARG TORCHAUDIO_VERSION="0.9.0" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index 02906e53b..ad23f8be7 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -5,8 +5,8 @@ ENV 
LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.10 -ARG K2_VERSION="1.24.4.dev20240211+cuda11.7.torch2.0.0" -ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.7.torch2.0.0" +ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0" ARG TORCHAUDIO_VERSION="2.0.0+cu117" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index c87305922..4e6812b83 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.10 -ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.1.0" -ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.1.0" +ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0" ARG TORCHAUDIO_VERSION="2.1.0+cu118" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index f4c297678..c7de4cf28 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.10 -ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.1.0" -ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.1.0" +ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0" ARG TORCHAUDIO_VERSION="2.1.0+cu121" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.2.0-cuda11.8.dockerfile b/docker/torch2.2.0-cuda11.8.dockerfile index c59661c27..0104ae870 100644 --- a/docker/torch2.2.0-cuda11.8.dockerfile +++ b/docker/torch2.2.0-cuda11.8.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.10 -ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.2.0" -ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.2.0" +ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0" ARG TORCHAUDIO_VERSION="2.2.0+cu118" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.2.0-cuda12.1.dockerfile b/docker/torch2.2.0-cuda12.1.dockerfile index 2c484efd5..ccd5265b2 100644 --- a/docker/torch2.2.0-cuda12.1.dockerfile +++ b/docker/torch2.2.0-cuda12.1.dockerfile @@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive # python 3.10 -ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.2.0" -ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.2.0" +ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0" ARG TORCHAUDIO_VERSION="2.2.0+cu121" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.2.1-cuda11.8.dockerfile b/docker/torch2.2.1-cuda11.8.dockerfile new file mode 100644 index 000000000..0528ba72f --- /dev/null +++ b/docker/torch2.2.1-cuda11.8.dockerfile @@ -0,0 +1,70 @@ +FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1" +ARG TORCHAUDIO_VERSION="2.2.1+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + 
vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.2.1-cuda12.1.dockerfile b/docker/torch2.2.1-cuda12.1.dockerfile new file mode 100644 index 000000000..3cdbb16ec --- /dev/null +++ b/docker/torch2.2.1-cuda12.1.dockerfile @@ -0,0 +1,70 @@ +FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1" +ARG TORCHAUDIO_VERSION="2.2.1+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index 149970eff..1acaa3d4f 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -34,6 +34,8 @@ which will give you something like below: .. 
code-block:: bash + "torch2.2.1-cuda12.1" + "torch2.2.1-cuda11.8" "torch2.2.0-cuda12.1" "torch2.2.0-cuda11.8" "torch2.1.0-cuda12.1" From 58610b1bf600d29aeec8dee48e470e84b2563947 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 29 Feb 2024 17:31:28 +0800 Subject: [PATCH 112/216] Provides `README.md` for TTS recipes (#1491) * Update README.md --- egs/ljspeech/TTS/README.md | 38 ++++++++++++++++++++++++++++++++++++++ egs/vctk/TTS/README.md | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 egs/ljspeech/TTS/README.md create mode 100644 egs/vctk/TTS/README.md diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md new file mode 100644 index 000000000..80be5a315 --- /dev/null +++ b/egs/ljspeech/TTS/README.md @@ -0,0 +1,38 @@ +# Introduction + +This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. +A transcription is provided for each clip. +Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours. + +The texts were published between 1884 and 1964, and are in the public domain. +The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain. + +The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/). + +# VITS + +This recipe provides a VITS model trained on the LJSpeech dataset. + +Pretrained model can be found [here](https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28). + +For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html). + +The training command is given below: +``` +export CUDA_VISIBLE_DEVICES=0,1,2,3 +./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --max-duration 500 +``` + +To inference, use: +``` +./vits/infer.py \ + --exp-dir vits/exp \ + --epoch 1000 \ + --tokens data/tokens.txt +``` \ No newline at end of file diff --git a/egs/vctk/TTS/README.md b/egs/vctk/TTS/README.md new file mode 100644 index 000000000..c07516b77 --- /dev/null +++ b/egs/vctk/TTS/README.md @@ -0,0 +1,37 @@ +# Introduction + +This CSTR VCTK Corpus includes speech data uttered by 110 English speakers with various accents. Each speaker reads out about 400 sentences, which were selected from a newspaper, the rainbow passage and an elicitation paragraph used for the speech accent archive. +The newspaper texts were taken from Herald Glasgow, with permission from Herald & Times Group. Each speaker has a different set of the newspaper texts selected based a greedy algorithm that increases the contextual and phonetic coverage. +The details of the text selection algorithms are described in the following paper: [C. Veaux, J. Yamagishi and S. King, "The voice bank corpus: Design, collection and data analysis of a large regional accent speech database,"](https://doi.org/10.1109/ICSDA.2013.6709856). + +The above information is from the [CSTR VCTK website](https://datashare.ed.ac.uk/handle/10283/3443). + +# VITS + +This recipe provides a VITS model trained on the VCTK dataset. + +Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05), note that this model was pretrained on the Edinburgh DataShare VCTK dataset. + +For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/vctk/vits.html). 
+ +The training command is given below: +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt \ + --max-duration 350 +``` + +To run inference, use: +``` +./vits/infer.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt \ + --max-duration 500 +``` \ No newline at end of file From 29b195a42e68fcbdd86de57f8edb0685886ffab5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 1 Mar 2024 19:53:58 +0800 Subject: [PATCH 113/216] Update export-onnx.py for vits to support sherpa-onnx. (#1524) --- egs/ljspeech/TTS/vits/README.md | 3 +- egs/ljspeech/TTS/vits/export-onnx.py | 152 ++++++++++++++++++++++++++- 2 files changed, 151 insertions(+), 4 deletions(-) diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md index 1141326b9..f2deed588 100644 --- a/egs/ljspeech/TTS/vits/README.md +++ b/egs/ljspeech/TTS/vits/README.md @@ -1,3 +1,4 @@ See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials. -Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29. +Training logs, Tensorboard logs, and checkpoints are uploaded to +https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28 diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index c607f0114..58b166368 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -91,7 +91,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): for key, value in meta_data.items(): meta = model.metadata_props.add() meta.key = key - meta.value = value + meta.value = str(value) onnx.save(model, filename) @@ -199,10 +199,15 @@ def export_model_onnx( ) meta_data = { - "model_type": "VITS", + "model_type": "vits", "version": "1", "model_author": "k2-fsa", - "comment": "VITS generator", + "comment": "icefall", # must be icefall for models from icefall + "language": "English", + "voice": "en-us", # Choose your language appropriately + "has_espeak": 1, + "n_speakers": 1, + "sample_rate": 22050, # Must match the real sample rate } logging.info(f"meta_data: {meta_data}") @@ -268,3 +273,144 @@ if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() + +""" +Supported languages. + +LJSpeech is using "en-us" from the second column.
+ +Pty Language Age/Gender VoiceName File Other Languages + 5 af --/M Afrikaans gmw/af + 5 am --/M Amharic sem/am + 5 an --/M Aragonese roa/an + 5 ar --/M Arabic sem/ar + 5 as --/M Assamese inc/as + 5 az --/M Azerbaijani trk/az + 5 ba --/M Bashkir trk/ba + 5 be --/M Belarusian zle/be + 5 bg --/M Bulgarian zls/bg + 5 bn --/M Bengali inc/bn + 5 bpy --/M Bishnupriya_Manipuri inc/bpy + 5 bs --/M Bosnian zls/bs + 5 ca --/M Catalan roa/ca + 5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr + 5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5) + 5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5) + 5 cs --/M Czech zlw/cs + 5 cv --/M Chuvash trk/cv + 5 cy --/M Welsh cel/cy + 5 da --/M Danish gmq/da + 5 de --/M German gmw/de + 5 el --/M Greek grk/el + 5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10) + 2 en-gb --/M English_(Great_Britain) gmw/en (en 2) + 5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4) + 5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5) + 5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9) + 5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5) + 2 en-us --/M English_(America) gmw/en-US (en 3) + 5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc + 5 eo --/M Esperanto art/eo + 5 es --/M Spanish_(Spain) roa/es + 5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6) + 5 et --/M Estonian urj/et + 5 eu --/M Basque eu + 5 fa --/M Persian ira/fa + 5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn + 5 fi --/M Finnish urj/fi + 5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8) + 5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8) + 5 fr-fr --/M French_(France) roa/fr (fr 5) + 5 ga --/M Gaelic_(Irish) cel/ga + 5 gd --/M Gaelic_(Scottish) cel/gd + 5 gn --/M Guarani sai/gn + 5 grc --/M Greek_(Ancient) grk/grc + 5 gu --/M Gujarati inc/gu + 5 hak --/M Hakka_Chinese sit/hak + 5 haw --/M Hawaiian map/haw + 5 he --/M Hebrew sem/he + 5 hi --/M Hindi inc/hi + 5 hr --/M Croatian zls/hr (hbs 5) + 5 ht --/M Haitian_Creole roa/ht + 5 hu --/M Hungarian urj/hu + 5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5) + 5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8) + 5 ia --/M Interlingua art/ia + 5 id --/M Indonesian poz/id + 5 io --/M Ido art/io + 5 is --/M Icelandic gmq/is + 5 it --/M Italian roa/it + 5 ja --/M Japanese jpx/ja + 5 jbo --/M Lojban art/jbo + 5 ka --/M Georgian ccs/ka + 5 kk --/M Kazakh trk/kk + 5 kl --/M Greenlandic esx/kl + 5 kn --/M Kannada dra/kn + 5 ko --/M Korean ko + 5 kok --/M Konkani inc/kok + 5 ku --/M Kurdish ira/ku + 5 ky --/M Kyrgyz trk/ky + 5 la --/M Latin itc/la + 5 lb --/M Luxembourgish gmw/lb + 5 lfn --/M Lingua_Franca_Nova art/lfn + 5 lt --/M Lithuanian bat/lt + 5 ltg --/M Latgalian bat/ltg + 5 lv --/M Latvian bat/lv + 5 mi --/M Māori poz/mi + 5 mk --/M Macedonian zls/mk + 5 ml --/M Malayalam dra/ml + 5 mr --/M Marathi inc/mr + 5 ms --/M Malay poz/ms + 5 mt --/M Maltese sem/mt + 5 mto --/M Totontepec_Mixe miz/mto + 5 my --/M Myanmar_(Burmese) sit/my + 5 nb --/M Norwegian_Bokmål gmq/nb (no 5) + 5 nci --/M Nahuatl_(Classical) azc/nci + 5 ne --/M Nepali inc/ne + 5 nl --/M Dutch gmw/nl + 5 nog --/M Nogai trk/nog + 5 om --/M Oromo cus/om + 5 or --/M Oriya inc/or + 5 pa --/M Punjabi inc/pa + 5 pap --/M Papiamento roa/pap + 5 piqd --/M Klingon art/piqd + 5 pl --/M Polish zlw/pl + 5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5) + 5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6) + 5 py --/M 
Pyash art/py + 5 qdb --/M Lang_Belta art/qdb + 5 qu --/M Quechua qu + 5 quc --/M K'iche' myn/quc + 5 qya --/M Quenya art/qya + 5 ro --/M Romanian roa/ro + 5 ru --/M Russian zle/ru + 5 ru-cl --/M Russian_(Classic) zle/ru-cl + 2 ru-lv --/M Russian_(Latvia) zle/ru-LV + 5 sd --/M Sindhi inc/sd + 5 shn --/M Shan_(Tai_Yai) tai/shn + 5 si --/M Sinhala inc/si + 5 sjn --/M Sindarin art/sjn + 5 sk --/M Slovak zlw/sk + 5 sl --/M Slovenian zls/sl + 5 smj --/M Lule_Saami urj/smj + 5 sq --/M Albanian ine/sq + 5 sr --/M Serbian zls/sr + 5 sv --/M Swedish gmq/sv + 5 sw --/M Swahili bnt/sw + 5 ta --/M Tamil dra/ta + 5 te --/M Telugu dra/te + 5 th --/M Thai tai/th + 5 tk --/M Turkmen trk/tk + 5 tn --/M Setswana bnt/tn + 5 tr --/M Turkish trk/tr + 5 tt --/M Tatar trk/tt + 5 ug --/M Uyghur trk/ug + 5 uk --/M Ukrainian zle/uk + 5 ur --/M Urdu inc/ur + 5 uz --/M Uzbek trk/uz + 5 vi --/M Vietnamese_(Northern) aav/vi + 5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central + 5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south + 5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8) + 5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8) +""" From 242002e0bd660fda6ca1067a326737ce86b8b4a3 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 4 Mar 2024 23:28:04 +0800 Subject: [PATCH 114/216] Strengthened style constraints (#1527) --- .github/workflows/style_check.yml | 8 ++++- .pre-commit-config.yaml | 2 +- .../onnx_check.py | 4 +-- .../export-for-ncnn.py | 2 +- .../ASR/local/preprocess_gigaspeech.py | 1 + .../pruned_transducer_stateless2/decode.py | 1 + egs/gigaspeech/ASR/zipformer/ctc_decode.py | 2 +- .../ASR/zipformer/streaming_decode.py | 2 +- egs/gigaspeech/KWS/zipformer/decode.py | 6 ++-- egs/gigaspeech/KWS/zipformer/finetune.py | 29 +++++++++---------- .../ASR/conformer_ctc3/test_model.py | 3 +- .../export-onnx.py | 2 +- .../jit_pretrained.py | 3 +- .../ASR/long_file_recog/recognize.py | 14 ++++----- .../lstm_transducer_stateless2/onnx_check.py | 3 +- .../pruned_transducer_stateless/my_profile.py | 3 +- .../onnx_decode.py | 3 +- .../onnx_check.py | 4 +-- .../onnx_decode.py | 3 +- .../my_profile.py | 8 ++--- .../onnx_decode.py | 3 +- .../pruned_transducer_stateless7/alignment.py | 1 - .../decode_gigaspeech.py | 6 ++-- .../generate_model_from_checkpoint.py | 3 +- .../my_profile.py | 8 ++--- .../onnx_decode.py | 3 +- .../test_model.py | 1 - .../ctc_guide_decode_bs.py | 2 +- .../lconv.py | 5 +--- .../onnx_pretrained.py | 2 +- .../onnx_wrapper.py | 1 + .../ncnn_custom_layer.py | 1 - .../streaming-ncnn-decode.py | 1 - .../ASR/tiny_transducer_ctc/decode.py | 3 +- egs/librispeech/ASR/zipformer/ctc_decode.py | 2 +- egs/librispeech/ASR/zipformer/model.py | 2 +- egs/librispeech/ASR/zipformer/my_profile.py | 14 ++++----- egs/librispeech/ASR/zipformer/onnx_decode.py | 5 ++-- .../ASR/zipformer/onnx_pretrained_ctc_H.py | 3 +- .../ASR/zipformer/onnx_pretrained_ctc_HL.py | 3 +- .../ASR/zipformer/onnx_pretrained_ctc_HLG.py | 3 +- egs/librispeech/ASR/zipformer/scaling.py | 11 +++---- .../ASR/zipformer/streaming_decode.py | 2 +- egs/librispeech/ASR/zipformer/subsampling.py | 4 +-- .../ASR/zipformer_adapter/decode.py | 20 +++++++++---- .../zipformer_adapter/decode_gigaspeech.py | 2 +- .../ASR/zipformer_adapter/export-onnx.py | 2 +- .../ASR/zipformer_adapter/onnx_decode.py | 7 +++-- .../ASR/zipformer_adapter/train.py | 12 +++++--- .../ASR/zipformer_adapter/zipformer.py | 14 ++++----- egs/must_c/ST/local/get_text.py | 2 +- egs/must_c/ST/local/get_words.py | 1 - 
egs/swbd/ASR/conformer_ctc/decode.py | 1 - egs/swbd/ASR/local/filter_empty_text.py | 2 +- .../jit_pretrained.py | 1 + egs/tedlium3/ASR/zipformer/model.py | 2 +- .../local/prepare_dataset_from_kaldi_dir.py | 7 +++-- egs/wenetspeech/ASR/local/prepare_pinyin.py | 1 + .../onnx_check.py | 4 +-- egs/wenetspeech/KWS/zipformer/decode.py | 4 +-- egs/wenetspeech/KWS/zipformer/finetune.py | 28 +++++++++--------- egs/wenetspeech/KWS/zipformer/train.py | 1 - egs/yesno/ASR/tdnn/jit_pretrained.py | 3 +- icefall/byte_utils.py | 1 - icefall/ctc/prepare_lang.py | 2 +- icefall/diagnostics.py | 2 +- icefall/profiler.py | 9 +++--- icefall/rnn_lm/export-onnx.py | 4 +-- icefall/utils.py | 4 +-- requirements.txt | 6 +++- 70 files changed, 168 insertions(+), 166 deletions(-) diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index fc1dcbfd4..1c37f13ed 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -49,7 +49,7 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 + python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1 # Click issue fixed in https://github.com/psf/black/pull/2966 - name: Run flake8 @@ -67,3 +67,9 @@ jobs: working-directory: ${{github.workspace}} run: | black --check --diff . + + - name: Run isort + shell: bash + working-directory: ${{github.workspace}} + run: | + isort --check --diff . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1bb38f6ba..5cb213327 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: # E121,E123,E126,E226,E24,E704,W503,W504 - repo: https://github.com/pycqa/isort - rev: 5.11.5 + rev: 5.10.1 hooks: - id: isort args: ["--profile=black"] diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py index 19c518eaf..f04537660 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py @@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging -from icefall import is_module_available +import torch from onnx_pretrained import OnnxModel -import torch +from icefall import is_module_available def get_parser(): diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index b210430c6..06a0fa96b 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -70,9 +70,9 @@ import logging from pathlib import Path import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled from tokenizer import Tokenizer -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index b6603f80d..a31685211 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -23,6 +23,7 @@ from pathlib import Path from lhotse import CutSet, SupervisionSegment from lhotse.recipes.utils import read_manifests_if_cached 
+ from icefall.utils import str2bool # Similar text filtering and normalization procedure as in: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 72f74c968..ef430302d 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -76,6 +76,7 @@ from beam_search import ( ) from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model + from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py index aa51036d5..651f20cb6 100755 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -88,7 +88,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py index 7cada8c9d..cb3fd0dc7 100755 --- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py +++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py @@ -51,7 +51,7 @@ from streaming_beam_search import ( ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py index 98b003937..0df2ec356 100755 --- a/egs/gigaspeech/KWS/zipformer/decode.py +++ b/egs/gigaspeech/KWS/zipformer/decode.py @@ -42,12 +42,10 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule -from beam_search import ( - keywords_search, -) +from beam_search import keywords_search +from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params -from lhotse.cut import Cut from icefall import ContextGraph from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py index b8e8802cb..2cd7c868b 100755 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -76,6 +76,20 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from train import ( + add_model_arguments, + add_training_arguments, + compute_loss, + compute_validation_loss, + display_and_save_batch, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + scan_pessimistic_batches_for_oom, + set_batch_count, +) from icefall import diagnostics from icefall.checkpoint import remove_checkpoints @@ -95,21 +109,6 @@ from icefall.utils import ( str2bool, ) -from train import ( - add_model_arguments, - add_training_arguments, - compute_loss, - compute_validation_loss, - display_and_save_batch, - get_adjusted_batch_count, - get_model, - get_params, - load_checkpoint_if_available, - save_checkpoint, - scan_pessimistic_batches_for_oom, - set_batch_count, 
-) - LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/librispeech/ASR/conformer_ctc3/test_model.py b/egs/librispeech/ASR/conformer_ctc3/test_model.py index b97b7eed8..aa12d6f83 100755 --- a/egs/librispeech/ASR/conformer_ctc3/test_model.py +++ b/egs/librispeech/ASR/conformer_ctc3/test_model.py @@ -24,8 +24,7 @@ To run this file, do: """ import torch - -from train import get_params, get_ctc_model +from train import get_ctc_model, get_params def test_model(): diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index 1e59e0858..79728afa4 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -59,9 +59,9 @@ import onnx import torch import torch.nn as nn from decoder import Decoder +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from emformer import Emformer from scaling_converter import convert_scaled_to_non_scaled -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py index 58f587c91..1deecbfc7 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py @@ -39,7 +39,7 @@ Usage of this script: import argparse import logging import math -from typing import List +from typing import List, Optional import kaldifeat import sentencepiece as spm @@ -47,7 +47,6 @@ import torch import torchaudio from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature from torch.nn.utils.rnn import pad_sequence -from typing import Optional, List def get_parser(): diff --git a/egs/librispeech/ASR/long_file_recog/recognize.py b/egs/librispeech/ASR/long_file_recog/recognize.py index 466253446..f4008c23b 100755 --- a/egs/librispeech/ASR/long_file_recog/recognize.py +++ b/egs/librispeech/ASR/long_file_recog/recognize.py @@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat """ import argparse -import torch.multiprocessing as mp -import torch -import torch.nn as nn import logging from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional, Tuple - from pathlib import Path +from typing import List, Optional, Tuple import k2 import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn from asr_datamodule import AsrDataModule from beam_search import ( fast_beam_search_one_best, greedy_search_batch, modified_beam_search, ) -from icefall.utils import AttributeDict, convert_timestamp, setup_logger from lhotse import CutSet, load_manifest_lazy from lhotse.cut import Cut -from lhotse.supervision import AlignmentItem from lhotse.serialization import SequentialJsonlWriter +from lhotse.supervision import AlignmentItem + +from icefall.utils import AttributeDict, convert_timestamp, setup_logger def get_parser(): diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py index c83f38b2a..85e0648d3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py +++ 
b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py @@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging +import torch from onnx_pretrained import OnnxModel from icefall import is_module_available -import torch - def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py index b844ba613..9762d878c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py @@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py import argparse import logging + import sentencepiece as spm import torch +from train import add_model_arguments, get_encoder_model, get_params from icefall.profiler import get_model_profile -from train import get_encoder_model, add_model_arguments, get_params def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py index 8134d43f8..a235d7b13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py @@ -75,8 +75,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 5ca4173c1..e2c1d6b5b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging -from icefall import is_module_available +import torch from onnx_pretrained import OnnxModel -import torch +from icefall import is_module_available def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py index 3b1c72cf1..f8fed9519 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py @@ -76,8 +76,7 @@ import torch import torch.nn as nn from asr_datamodule import AsrDataModule from librispeech import LibriSpeech - -from onnx_pretrained import greedy_search, OnnxModel +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py index 4bf773918..cf0598ca3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py @@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py import argparse import logging +from typing import Tuple + import sentencepiece as spm import torch - -from typing import Tuple +from scaling import BasicNorm, DoubleSwish from torch import Tensor, nn +from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params from icefall.profiler import 
get_model_profile -from scaling import BasicNorm, DoubleSwish -from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py index 6f26e34b5..b0f76317b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py @@ -82,8 +82,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py index bfb5fe609..ee8196c3f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -20,7 +20,6 @@ from typing import List import k2 import torch - from beam_search import Hypothesis, HypothesisList, get_hyps_shape # The force alignment problem can be formulated as finding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py index b0e4be0d1..7095c3cc8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py @@ -107,9 +107,6 @@ import k2 import sentencepiece as spm import torch import torch.nn as nn - -# from asr_datamodule import LibriSpeechAsrDataModule -from gigaspeech import GigaSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -120,6 +117,9 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) + +# from asr_datamodule import LibriSpeechAsrDataModule +from gigaspeech import GigaSpeechAsrDataModule from gigaspeech_scoring import asr_text_post_processing from train import add_model_arguments, get_params, get_transducer_model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py index 37edc0390..3fd14aa47 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py @@ -65,16 +65,15 @@ from typing import Dict, List import sentencepiece as spm import torch - from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) +from icefall.utils import str2bool def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py index 5a068b3b6..1416c6828 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py @@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py import argparse import logging +from typing import Tuple + import sentencepiece as spm import torch - -from typing import Tuple +from scaling import 
BasicNorm, DoubleSwish from torch import Tensor, nn +from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params from icefall.profiler import get_model_profile -from scaling import BasicNorm, DoubleSwish -from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py index 67585ee47..e00281239 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py @@ -75,8 +75,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py index cdf914df3..1f50eb309 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py @@ -24,7 +24,6 @@ To run this file, do: """ import torch - from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py index 01ba7b711..e2f08abc6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -118,8 +118,8 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import add_model_arguments, get_params, get_transducer_model from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py index a902358ae..2faec7ade 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py @@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - ScaledConv1d, -) +from scaling import ActivationBalancer, ScaledConv1d class LConv(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py index 0ff110370..3a16985bc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -52,7 +52,7 @@ import onnxruntime as ort import sentencepiece as spm import torch import torchaudio -from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence +from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence from icefall.utils import make_pad_mask diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py index 
247da0949..07e97bbdb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py @@ -14,6 +14,7 @@ import torch from torch import nn + from icefall.utils import make_pad_mask diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py index 442a0a8af..451c35332 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py @@ -4,7 +4,6 @@ import ncnn import numpy as np - layer_list = [] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py index 999f7e0b4..06127607d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -42,7 +42,6 @@ import ncnn import torch import torchaudio from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature - from ncnn_custom_layer import RegisterCustomLayers diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py index 6c2bf9ea1..cc4471e2b 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py @@ -1,10 +1,11 @@ import argparse import logging import math +import pprint from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple -import pprint + import k2 import sentencepiece as spm import torch diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 4db50b981..1f0f9bfac 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -88,7 +88,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 73009d35c..86da3ab29 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -22,9 +22,9 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear class AsrModel(nn.Module): diff --git a/egs/librispeech/ASR/zipformer/my_profile.py b/egs/librispeech/ASR/zipformer/my_profile.py index ca20956fb..7e1fd777a 100755 --- a/egs/librispeech/ASR/zipformer/my_profile.py +++ b/egs/librispeech/ASR/zipformer/my_profile.py @@ -22,24 +22,24 @@ Usage: ./zipformer/my_profile.py import argparse import logging +from typing import Tuple + import sentencepiece as spm import torch - -from typing import Tuple -from torch import Tensor, nn - -from icefall.utils import make_pad_mask -from icefall.profiler import get_model_profile from scaling import BiasNorm +from torch import Tensor, nn from train import ( + add_model_arguments, get_encoder_embed, get_encoder_model, get_joiner_model, - add_model_arguments, 
get_params, ) from zipformer import BypassModule +from icefall.profiler import get_model_profile +from icefall.utils import make_pad_mask + def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py index 356c2a830..449294444 100755 --- a/egs/librispeech/ASR/zipformer/onnx_decode.py +++ b/egs/librispeech/ASR/zipformer/onnx_decode.py @@ -77,11 +77,10 @@ from typing import List, Tuple import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats -from k2 import SymbolTable def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py index a77c3bf2a..114490599 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 import argparse import logging import math -from typing import List, Tuple +from typing import Dict, List, Tuple import k2 import kaldifeat -from typing import Dict import kaldifst import onnxruntime as ort import torch diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py index 6ef944514..f7d3e5253 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 import argparse import logging import math -from typing import List, Tuple +from typing import Dict, List, Tuple import k2 import kaldifeat -from typing import Dict import kaldifst import onnxruntime as ort import torch diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py index ccb3107ea..ebd385364 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 import argparse import logging import math -from typing import List, Tuple +from typing import Dict, List, Tuple import k2 import kaldifeat -from typing import Dict import kaldifst import onnxruntime as ort import torch diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c0f1e3087..29ac33c02 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -15,15 +15,16 @@ # limitations under the License. 
-from typing import Optional, Tuple, Union import logging -import k2 -from torch.cuda.amp import custom_fwd, custom_bwd -import random -import torch import math +import random +from typing import Optional, Tuple, Union + +import k2 +import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 8087c1460..360523b8e 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -51,7 +51,7 @@ from streaming_beam_search import ( ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index d16d87bac..b2f769d3f 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -16,11 +16,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import warnings +from typing import Tuple import torch -from torch import Tensor, nn from scaling import ( Balancer, BiasNorm, @@ -34,6 +33,7 @@ from scaling import ( SwooshR, Whiten, ) +from torch import Tensor, nn class ConvNeXt(nn.Module): diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py index bfa4cc230..91533be8d 100755 --- a/egs/librispeech/ASR/zipformer_adapter/decode.py +++ b/egs/librispeech/ASR/zipformer_adapter/decode.py @@ -858,7 +858,9 @@ def main(): logging.info("About to create model") model = get_model(params) - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() if not params.use_averaged_model: if params.iter > 0: @@ -877,9 +879,13 @@ def main(): ) logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device), strict=False) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False) + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False + ) else: start = params.epoch - params.avg + 1 filenames = [] @@ -888,7 +894,9 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device), strict=False) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ @@ -917,7 +925,7 @@ def main(): filename_end=filename_end, device=device, ), - strict=False + strict=False, ) else: assert params.avg > 0, params.avg @@ -936,7 +944,7 @@ def main(): filename_end=filename_end, device=device, ), - strict=False + strict=False, ) model.to(device) diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py index 903014f4a..bbc582f50 100755 --- a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py +++ 
b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py @@ -121,7 +121,7 @@ from beam_search import ( modified_beam_search_lm_shallow_fusion, modified_beam_search_LODR, ) -from train import add_model_arguments, add_finetune_arguments, get_model, get_params +from train import add_finetune_arguments, add_model_arguments, get_model, get_params from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( diff --git a/egs/librispeech/ASR/zipformer_adapter/export-onnx.py b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py index a1fc41664..ea29e8159 100755 --- a/egs/librispeech/ASR/zipformer_adapter/export-onnx.py +++ b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py @@ -72,7 +72,7 @@ import torch.nn as nn from decoder import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, add_finetune_arguments, get_model, get_params +from train import add_finetune_arguments, add_model_arguments, get_model, get_params from zipformer import Zipformer2 from icefall.checkpoint import ( diff --git a/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py b/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py index 000cea163..e3f7ce85a 100755 --- a/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py +++ b/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py @@ -77,11 +77,10 @@ from typing import List, Tuple import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats -from k2 import SymbolTable conversational_filler = [ "UH", @@ -182,6 +181,7 @@ def get_parser(): return parser + def post_processing( results: List[Tuple[str, List[str], List[str]]], ) -> List[Tuple[str, List[str], List[str]]]: @@ -192,6 +192,7 @@ def post_processing( new_results.append((key, new_ref, new_hyp)) return new_results + def decode_one_batch( model: OnnxModel, token_table: SymbolTable, batch: dict ) -> List[List[str]]: diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 7f81ddd96..e64c10e7a 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -121,7 +121,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): default=True, help="If true, finetune from a pre-trained checkpoint", ) - + parser.add_argument( "--use-mux", type=str2bool, @@ -137,14 +137,14 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): "--use-adapters", type=str2bool, default=True, - help="If use adapter to finetune the model" + help="If use adapter to finetune the model", ) parser.add_argument( "--adapter-dim", type=int, default=16, - help="The bottleneck dimension of the adapter" + help="The bottleneck dimension of the adapter", ) parser.add_argument( @@ -1273,7 +1273,11 @@ def run(rank, world_size, args): else: p.requires_grad = False - logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100)) + logging.info( + "A total of {} trainable parameters ({:.3f}% of the whole model)".format( + num_trainable, num_trainable / num_param * 100 + ) + ) model.to(device) if world_size > 1: diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py 
b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index e4e26cd84..4e4695fa5 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -40,13 +40,13 @@ from scaling import ( Dropout2, FloatLike, ScheduledFloat, + SwooshL, + SwooshR, Whiten, convert_num_channels, limit_param_value, penalize_abs_values_gt, softmax, - SwooshL, - SwooshR, ) from torch import Tensor, nn @@ -601,8 +601,8 @@ class Zipformer2EncoderLayer(nn.Module): bypass_skip_rate: FloatLike = ScheduledFloat( (0.0, 0.5), (4000.0, 0.02), default=0 ), - use_adapters: bool=False, - adapter_dim: int=16, + use_adapters: bool = False, + adapter_dim: int = 16, ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -737,7 +737,7 @@ class Zipformer2EncoderLayer(nn.Module): embed_dim=embed_dim, bottleneck_dim=adapter_dim, ) - + # placed after the 2nd convolution module self.post_conv_adapter = AdapterModule( embed_dim=embed_dim, @@ -2488,8 +2488,8 @@ def _test_zipformer_main(causal: bool = False): class AdapterModule(nn.Module): def __init__( self, - embed_dim: int=384, - bottleneck_dim: int=16, + embed_dim: int = 384, + bottleneck_dim: int = 16, ): # The simplest adapter super(AdapterModule, self).__init__() diff --git a/egs/must_c/ST/local/get_text.py b/egs/must_c/ST/local/get_text.py index 558ab6de8..f7b5816a8 100755 --- a/egs/must_c/ST/local/get_text.py +++ b/egs/must_c/ST/local/get_text.py @@ -5,9 +5,9 @@ This file prints the text field of supervisions from cutset to the console """ import argparse +from pathlib import Path from lhotse import load_manifest_lazy -from pathlib import Path def get_args(): diff --git a/egs/must_c/ST/local/get_words.py b/egs/must_c/ST/local/get_words.py index a61f60860..b32925099 100755 --- a/egs/must_c/ST/local/get_words.py +++ b/egs/must_c/ST/local/get_words.py @@ -5,7 +5,6 @@ This file generates words.txt from the given transcript file. """ import argparse - from pathlib import Path diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index 2bbade374..52e501ae1 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -29,7 +29,6 @@ import torch import torch.nn as nn from asr_datamodule import SwitchBoardAsrDataModule from conformer import Conformer - from sclite_scoring import asr_text_post_processing from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py index 6b3316800..13b35980b 100755 --- a/egs/swbd/ASR/local/filter_empty_text.py +++ b/egs/swbd/ASR/local/filter_empty_text.py @@ -16,8 +16,8 @@ # limitations under the License. 
import argparse -from pathlib import Path import logging +from pathlib import Path from typing import List diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py index 8c966a2f6..503cdf4ed 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -45,6 +45,7 @@ import sentencepiece as spm import torch import torchaudio from torch.nn.utils.rnn import pad_sequence + from icefall import smart_byte_decode diff --git a/egs/tedlium3/ASR/zipformer/model.py b/egs/tedlium3/ASR/zipformer/model.py index 90ec7e7aa..65b052ab9 100644 --- a/egs/tedlium3/ASR/zipformer/model.py +++ b/egs/tedlium3/ASR/zipformer/model.py @@ -19,9 +19,9 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear class Transducer(nn.Module): diff --git a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py index 334a6d023..52da3d6dc 100644 --- a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py +++ b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py @@ -17,10 +17,10 @@ import argparse import logging - -import torch -import lhotse from pathlib import Path + +import lhotse +import torch from lhotse import ( CutSet, Fbank, @@ -29,6 +29,7 @@ from lhotse import ( fix_manifests, validate_recordings_and_supervisions, ) + from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or diff --git a/egs/wenetspeech/ASR/local/prepare_pinyin.py b/egs/wenetspeech/ASR/local/prepare_pinyin.py index ae40f1cdd..112b50b79 100755 --- a/egs/wenetspeech/ASR/local/prepare_pinyin.py +++ b/egs/wenetspeech/ASR/local/prepare_pinyin.py @@ -41,6 +41,7 @@ from prepare_lang import ( write_lexicon, write_mapping, ) + from icefall.utils import text_to_pinyin diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py index ee8252a90..8c192913e 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py @@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging -from icefall import is_module_available +import torch from onnx_pretrained import OnnxModel -import torch +from icefall import is_module_available def get_parser(): diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 5ed3c6c2c..340a41231 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -30,9 +30,7 @@ import k2 import torch import torch.nn as nn from asr_datamodule import WenetSpeechAsrDataModule -from beam_search import ( - keywords_search, -) +from beam_search import keywords_search from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index 6f34989e2..76df7e8d5 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -87,6 +87,19 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import 
DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from train import ( + add_model_arguments, + add_training_arguments, + compute_validation_loss, + display_and_save_batch, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + scan_pessimistic_batches_for_oom, + set_batch_count, +) from icefall import diagnostics from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler @@ -109,21 +122,6 @@ from icefall.utils import ( text_to_pinyin, ) -from train import ( - add_model_arguments, - add_training_arguments, - compute_validation_loss, - display_and_save_batch, - get_adjusted_batch_count, - get_model, - get_params, - load_checkpoint_if_available, - save_checkpoint, - scan_pessimistic_batches_for_oom, - set_batch_count, -) - - LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 5be34ed99..05acbd6a9 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -99,7 +99,6 @@ from icefall.utils import ( text_to_pinyin, ) - LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index e29415ffb..6c643c263 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -18,9 +18,8 @@ you can use ./export.py --jit 1 import argparse import logging -from typing import List import math - +from typing import List import k2 import kaldifeat diff --git a/icefall/byte_utils.py b/icefall/byte_utils.py index 79c1c7545..5f5cc710b 100644 --- a/icefall/byte_utils.py +++ b/icefall/byte_utils.py @@ -8,7 +8,6 @@ import re import unicodedata - WHITESPACE_NORMALIZER = re.compile(r"\s+") SPACE = chr(32) SPACE_ESCAPE = chr(9601) diff --git a/icefall/ctc/prepare_lang.py b/icefall/ctc/prepare_lang.py index 4801b1beb..0e99e70d8 100644 --- a/icefall/ctc/prepare_lang.py +++ b/icefall/ctc/prepare_lang.py @@ -8,12 +8,12 @@ The lang_dir should contain the following files: """ import math +import re from collections import defaultdict from pathlib import Path from typing import List, Tuple import kaldifst -import re class Lexicon: diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 65b6f67b0..a3c480c9c 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -18,7 +18,7 @@ import random from dataclasses import dataclass -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import torch from torch import Tensor, nn diff --git a/icefall/profiler.py b/icefall/profiler.py index 49e138579..762105c48 100644 --- a/icefall/profiler.py +++ b/icefall/profiler.py @@ -5,14 +5,15 @@ # This is modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py +from collections import OrderedDict +from functools import partial +from typing import List, Optional + import k2 +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from functools import partial -from typing import List, Optional -from collections import OrderedDict -import numpy as np Tensor = torch.Tensor diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py index dfede708b..1070d443a 100755 --- a/icefall/rnn_lm/export-onnx.py +++ b/icefall/rnn_lm/export-onnx.py @@ -5,16 +5,16 @@ import argparse import logging from pathlib 
import Path +from typing import Dict import onnx import torch from model import RnnLmModel from onnxruntime.quantization import QuantType, quantize_dynamic +from train import get_params from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, str2bool -from typing import Dict -from train import get_params def add_meta_data(filename: str, meta_data: Dict[str, str]): diff --git a/icefall/utils.py b/icefall/utils.py index 7d722b1bc..31f9801d9 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -28,8 +28,6 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from pathlib import Path -from pypinyin import pinyin, lazy_pinyin -from pypinyin.contrib.tone_convert import to_initials, to_finals_tone, to_finals from shutil import copyfile from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union @@ -40,6 +38,8 @@ import sentencepiece as spm import torch import torch.distributed as dist import torch.nn as nn +from pypinyin import lazy_pinyin, pinyin +from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import average_checkpoints diff --git a/requirements.txt b/requirements.txt index dec20c6e4..e64afd1ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,10 @@ pypinyin==0.50.0 tensorboard typeguard dill -black==22.3.0 onnx==1.15.0 onnxruntime==1.16.3 + +# style check session: +black==22.3.0 +isort==5.10.1 +flake8==5.0.4 \ No newline at end of file From ff430b465fcfe782551bc85105673740eaeeca83 Mon Sep 17 00:00:00 2001 From: Rezakh20 <160485045+Rezakh20@users.noreply.github.com> Date: Tue, 5 Mar 2024 12:10:30 +0330 Subject: [PATCH 115/216] Add num_features to train.py for training WSASR (#1528) --- egs/librispeech/WSASR/conformer_ctc2/train.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py index fe6c5af91..daff40d59 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -31,6 +31,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir conformer_ctc2/exp \ --lang-dir data/lang_bpe_200 \ --otc-token "" \ + --feature-dim 768 \ --allow-bypass-arc true \ --allow-self-loop-arc true \ --initial-bypass-weight -19 \ @@ -159,6 +160,14 @@ def get_parser(): "lexicon.txt" """, ) + + parser.add_argument( + "--feature-dim", + type=int, + default=768, + help="""Number of features extracted in feature extraction stage.last dimension of feature vector. 
+ 80 when using fbank features and 768 or 1024 when using wav2vec""", + ) parser.add_argument( "--initial-lr", @@ -385,7 +394,6 @@ def get_params() -> AttributeDict: "valid_interval": 800, # For the 100h subset, use 800 "alignment_interval": 25, # parameters for conformer - "feature_dim": 768, "subsampling_factor": 2, "encoder_dim": 512, "nhead": 8, From 335a9962dece1f281731edb3b5f4dc9904e46328 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 6 Mar 2024 08:43:45 +0800 Subject: [PATCH 116/216] Fixed formatting issue of PR #1528 (#1530) --- egs/librispeech/WSASR/conformer_ctc2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py index daff40d59..82c68803f 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -160,7 +160,7 @@ def get_parser(): "lexicon.txt" """, ) - + parser.add_argument( + "--feature-dim", + type=int, From cdb3fb5675b37b4ad2725e69bce609df019ffcd7 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 7 Mar 2024 18:47:29 +0800 Subject: [PATCH 117/216] add text norm script for pl (#1532) --- egs/commonvoice/ASR/local/preprocess_commonvoice.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index 5f6aa3ec0..c0f4ca427 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -48,8 +48,18 @@ def normalize_text(utt: str, language: str) -> str: utt = re.sub("’", "'", utt) if language == "en": return re.sub(r"[^a-zA-Z\s]", "", utt).upper() - if language == "fr": + elif language == "fr": return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper() + elif language == "pl": + return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper() + else: + raise NotImplementedError( + f""" + Text normalization not implemented for language: {language}, + please consider implementing it in the local/preprocess_commonvoice.py + or raise an issue on GitHub to request it.
+ """ + ) def preprocess_commonvoice( From 5df24c168519239b8c2f8327b666cae61d359a65 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 7 Mar 2024 18:04:27 +0700 Subject: [PATCH 118/216] Whisper large fine-tuning on wenetspeech, mutli-hans-zh (#1483) * add whisper fbank for wenetspeech * add whisper fbank for other dataset * add str to bool * add decode for wenetspeech * add requirments.txt * add original model decode with 30s * test feature extractor speed * add aishell2 feat * change compute feature batch * fix overwrite * fix executor * regression * add kaldifeatwhisper fbank * fix io issue * parallel jobs * use multi machines * add wenetspeech fine-tune scripts * add monkey patch codes * remove useless file * fix subsampling factor * fix too long audios * add remove long short * fix whisper version to support multi batch beam * decode all wav files * remove utterance more than 30s in test_net * only test net * using soft links * add kespeech whisper feats * fix index error * add manifests for whisper * change to licomchunky writer * add missing option * decrease cpu usage * add speed perturb for kespeech * fix kespeech speed perturb * add dataset * load checkpoint from specific path * add speechio * add speechio results --------- Co-authored-by: zr_jin --- egs/aishell/ASR/RESULTS.md | 2 +- egs/aishell2/ASR/RESULTS.md | 2 +- .../ASR/local/compute_fbank_aishell2.py | 36 +- egs/aishell2/ASR/prepare.sh | 10 + egs/aishell4/ASR/README.md | 2 +- .../ASR/local/compute_fbank_aishell4.py | 37 +- egs/aishell4/ASR/prepare.sh | 29 +- .../ASR/local/compute_fbank_alimeeting.py | 37 +- egs/alimeeting/ASR/prepare.sh | 30 +- egs/multi_zh-hans/ASR/README.md | 2 +- .../local/compute_fbank_kespeech_dev_test.py | 48 +- .../local/compute_fbank_kespeech_splits.py | 29 +- .../ASR/local/compute_fbank_magicdata.py | 56 +- .../ASR/local/compute_fbank_primewords.py | 33 +- .../ASR/local/compute_fbank_stcmds.py | 32 +- .../ASR/local/compute_fbank_thchs30.py | 32 +- egs/multi_zh-hans/ASR/prepare.sh | 171 ++- .../ASR/whisper/asr_datamodule.py | 1 + egs/multi_zh-hans/ASR/whisper/decode.py | 519 ++++++++ .../ASR/whisper/ds_config_zero1.json | 1 + .../ASR/whisper/label_smoothing.py | 1 + .../ASR/whisper/multi_dataset.py | 296 +++++ egs/multi_zh-hans/ASR/whisper/optim.py | 1 + .../ASR/whisper/requirements.txt | 1 + egs/multi_zh-hans/ASR/whisper/train.py | 983 ++++++++++++++ .../whisper_encoder_forward_monkey_patch.py | 1 + egs/speechio/ASR/README.md | 15 + egs/speechio/ASR/RESULTS.md | 92 ++ .../ASR/local/compute_fbank_speechio.py | 148 +++ .../ASR/local/display_manifest_statistics.py | 1162 +++++++++++++++++ .../ASR/local/whisper_zipformer_fusion.py | 217 +++ egs/speechio/ASR/prepare.sh | 67 + egs/speechio/ASR/shared | 1 + egs/speechio/ASR/whisper/asr_datamodule.py | 195 +++ egs/speechio/ASR/whisper/decode.py | 520 ++++++++ egs/speechio/ASR/whisper/multi_dataset.py | 59 + egs/speechio/ASR/whisper/requirements.txt | 1 + .../whisper_encoder_forward_monkey_patch.py | 1 + egs/speechio/ASR/zipformer/asr_datamodule.py | 1 + egs/speechio/ASR/zipformer/beam_search.py | 1 + egs/speechio/ASR/zipformer/ctc_decode.py | 623 +++++++++ egs/speechio/ASR/zipformer/decode.py | 843 ++++++++++++ egs/speechio/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + egs/speechio/ASR/zipformer/joiner.py | 1 + egs/speechio/ASR/zipformer/model.py | 1 + egs/speechio/ASR/zipformer/multi_dataset.py | 1 + egs/speechio/ASR/zipformer/optim.py | 1 + egs/speechio/ASR/zipformer/scaling.py | 1 + 
.../ASR/zipformer/scaling_converter.py | 1 + egs/speechio/ASR/zipformer/subsampling.py | 1 + egs/speechio/ASR/zipformer/train.py | 1 + egs/speechio/ASR/zipformer/zipformer.py | 1 + .../compute_fbank_wenetspeech_dev_test.py | 48 +- .../local/compute_fbank_wenetspeech_splits.py | 42 +- egs/wenetspeech/ASR/prepare.sh | 47 +- egs/wenetspeech/ASR/whisper/asr_datamodule.py | 1 + egs/wenetspeech/ASR/whisper/decode.py | 526 ++++++++ .../ASR/whisper/ds_config_zero1.json | 1 + .../ASR/whisper/label_smoothing.py | 1 + egs/wenetspeech/ASR/whisper/optim.py | 1 + egs/wenetspeech/ASR/whisper/requirements.txt | 1 + egs/wenetspeech/ASR/whisper/train.py | 955 ++++++++++++++ .../whisper_encoder_forward_monkey_patch.py | 1 + 64 files changed, 7844 insertions(+), 129 deletions(-) create mode 120000 egs/multi_zh-hans/ASR/whisper/asr_datamodule.py create mode 100644 egs/multi_zh-hans/ASR/whisper/decode.py create mode 120000 egs/multi_zh-hans/ASR/whisper/ds_config_zero1.json create mode 120000 egs/multi_zh-hans/ASR/whisper/label_smoothing.py create mode 100644 egs/multi_zh-hans/ASR/whisper/multi_dataset.py create mode 120000 egs/multi_zh-hans/ASR/whisper/optim.py create mode 120000 egs/multi_zh-hans/ASR/whisper/requirements.txt create mode 100644 egs/multi_zh-hans/ASR/whisper/train.py create mode 120000 egs/multi_zh-hans/ASR/whisper/whisper_encoder_forward_monkey_patch.py create mode 100644 egs/speechio/ASR/README.md create mode 100644 egs/speechio/ASR/RESULTS.md create mode 100644 egs/speechio/ASR/local/compute_fbank_speechio.py create mode 100644 egs/speechio/ASR/local/display_manifest_statistics.py create mode 100644 egs/speechio/ASR/local/whisper_zipformer_fusion.py create mode 100644 egs/speechio/ASR/prepare.sh create mode 120000 egs/speechio/ASR/shared create mode 100644 egs/speechio/ASR/whisper/asr_datamodule.py create mode 100644 egs/speechio/ASR/whisper/decode.py create mode 100644 egs/speechio/ASR/whisper/multi_dataset.py create mode 120000 egs/speechio/ASR/whisper/requirements.txt create mode 120000 egs/speechio/ASR/whisper/whisper_encoder_forward_monkey_patch.py create mode 120000 egs/speechio/ASR/zipformer/asr_datamodule.py create mode 120000 egs/speechio/ASR/zipformer/beam_search.py create mode 100644 egs/speechio/ASR/zipformer/ctc_decode.py create mode 100644 egs/speechio/ASR/zipformer/decode.py create mode 120000 egs/speechio/ASR/zipformer/decoder.py create mode 120000 egs/speechio/ASR/zipformer/encoder_interface.py create mode 120000 egs/speechio/ASR/zipformer/joiner.py create mode 120000 egs/speechio/ASR/zipformer/model.py create mode 120000 egs/speechio/ASR/zipformer/multi_dataset.py create mode 120000 egs/speechio/ASR/zipformer/optim.py create mode 120000 egs/speechio/ASR/zipformer/scaling.py create mode 120000 egs/speechio/ASR/zipformer/scaling_converter.py create mode 120000 egs/speechio/ASR/zipformer/subsampling.py create mode 120000 egs/speechio/ASR/zipformer/train.py create mode 120000 egs/speechio/ASR/zipformer/zipformer.py create mode 120000 egs/wenetspeech/ASR/whisper/asr_datamodule.py create mode 100755 egs/wenetspeech/ASR/whisper/decode.py create mode 120000 egs/wenetspeech/ASR/whisper/ds_config_zero1.json create mode 120000 egs/wenetspeech/ASR/whisper/label_smoothing.py create mode 120000 egs/wenetspeech/ASR/whisper/optim.py create mode 120000 egs/wenetspeech/ASR/whisper/requirements.txt create mode 100644 egs/wenetspeech/ASR/whisper/train.py create mode 120000 egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py diff --git a/egs/aishell/ASR/RESULTS.md 
b/egs/aishell/ASR/RESULTS.md index 46d712fb2..355d1516d 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc | fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 | ```bash -./prepare.sh +./prepare.sh export CUDA_VISIBLE_DEVICES="0,1" diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md index 32ad74b50..0b7ae9299 100644 --- a/egs/aishell2/ASR/RESULTS.md +++ b/egs/aishell2/ASR/RESULTS.md @@ -1,6 +1,6 @@ ## Results -### Aishell2 char-based training results +### Aishell2 char-based training results #### Pruned transducer stateless 5 diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index 1fb1621ff..557f22b0c 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -29,7 +29,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,10 +49,12 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False): +def compute_fbank_aishell2( + num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests") output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) + num_jobs = min(8, os.cpu_count()) dataset_parts = ( "train", @@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False): list(manifests.keys()), dataset_parts, ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False): supervisions=m["supervisions"], ) if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -111,7 +124,12 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) - + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) return parser.parse_args() @@ -122,5 +140,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_aishell2( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index a5eb9bd13..c959bd4d1 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi fi +whisper_mel_bins=80 +if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then + log "Stage 30: Compute whisper fbank for aishell2" + if [ ! 
-f data/fbank/.aishell2.whisper.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + touch data/fbank/.aishell2.whisper.done + fi +fi + if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Compute fbank for musan" if [ ! -f data/fbank/.msuan.done ]; then diff --git a/egs/aishell4/ASR/README.md b/egs/aishell4/ASR/README.md index 67fa17790..b96161762 100644 --- a/egs/aishell4/ASR/README.md +++ b/egs/aishell4/ASR/README.md @@ -3,7 +3,7 @@ This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets). -The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. +The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. 
(From [Open Speech and Language Resources](https://www.openslr.org/111/)) diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index f19163988..b5f8468ac 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -29,7 +29,14 @@ import os from pathlib import Path import torch -from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,10 +49,12 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False): +def compute_fbank_aishell4( + num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/aishell4") output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) + num_jobs = min(8, os.cpu_count()) dataset_parts = ( "train_S", @@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False): dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False): supervisions=m["supervisions"], ) if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False): # when an executor is specified, make more partitions num_jobs=num_jobs if ex is None else 80, executor=ex, - storage_type=ChunkedLilcomHdf5Writer, + storage_type=LilcomChunkyWriter, ) logging.info("About splitting cuts into smaller chunks") @@ -121,7 +135,12 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) - + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) return parser.parse_args() @@ -132,5 +151,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_aishell4( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index e8d9eb7b9..38a36d97a 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail stage=-1 -stop_stage=100 +stop_stage=7 perturb_speed=true @@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Process aishell4" + log "Stage 2: Compute fbank for aishell4" if [ ! 
-f data/fbank/aishell4/.fbank.done ]; then - mkdir -p data/fbank/aishell4 + mkdir -p data/fbank ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} - touch data/fbank/aishell4/.fbank.done + touch data/fbank/.fbank.done + fi +fi + +whisper_mel_bins=80 +if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then + log "Stage 20: Compute whisper fbank for aishell4" + if [ ! -f data/fbank/aishell4/.fbank.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + touch data/fbank/.fbank.done fi fi @@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank for aishell4" - if [ ! -f data/fbank/.aishell4.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} - touch data/fbank/.aishell4.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare char based lang" + log "Stage 5: Prepare char based lang" lang_char_dir=data/lang_char mkdir -p $lang_char_dir diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index f8c10648a..09c873a34 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -29,7 +29,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,10 +49,12 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False): +def compute_fbank_alimeeting( + num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/alimeeting") output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) + num_jobs = min(8, os.cpu_count()) dataset_parts = ( "train", @@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False "test", ) - prefix = "alimeeting" + prefix = "alimeeting-far" suffix = "jsonl.gz" manifests = read_manifests_if_cached( dataset_parts=dataset_parts, @@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False supervisions=m["supervisions"], ) if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -121,7 +135,12 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) - + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use the Whisper Fbank feature extractor. 
Default: False.", + ) return parser.parse_args() @@ -132,5 +151,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_alimeeting( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index c8fed658d..301ab0111 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail stage=-1 -stop_stage=100 +stop_stage=7 perturb_speed=true # We assume dl_dir (download dir) contains the following @@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Process alimeeting" - if [ ! -f data/fbank/alimeeting/.fbank.done ]; then - mkdir -p data/fbank/alimeeting + log "Stage 2: compute fbank for alimeeting" + if [ ! -f data/fbank/.fbank.done ]; then + mkdir -p data/fbank ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} + touch data/fbank/.fbank.done + fi +fi + +whisper_mel_bins=80 +if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then + log "Stage 20: compute whisper fbank for alimeeting" + if [ ! -f data/fbank/.fbank.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + touch data/fbank/.fbank.done fi fi @@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank for alimeeting" - if [ ! -f data/fbank/.alimeeting.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_alimeeting.py --perturb-speed True - touch data/fbank/.alimeeting.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare char based lang" + log "Stage 5: Prepare char based lang" lang_char_dir=data/lang_char mkdir -p $lang_char_dir diff --git a/egs/multi_zh-hans/ASR/README.md b/egs/multi_zh-hans/ASR/README.md index 537816a5d..1e60c733c 100644 --- a/egs/multi_zh-hans/ASR/README.md +++ b/egs/multi_zh-hans/ASR/README.md @@ -36,4 +36,4 @@ This recipe includes scripts for training Zipformer model using multiple Chinese 3. AliMeeting 4. MagicData 5. KeSpeech-ASR -6. WeNetSpeech \ No newline at end of file +6. WeNetSpeech diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py index 2581ee42f..6f75dbfa4 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py @@ -17,11 +17,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging from pathlib import Path import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) + +from icefall.utils import str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
@@ -31,7 +41,28 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_kespeech_dev_test(): +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + return parser + + +def compute_fbank_kespeech_dev_test(args): in_out_dir = Path("data/fbank/kespeech") # number of workers in dataloader num_workers = 42 @@ -48,7 +79,12 @@ def compute_fbank_kespeech_dev_test(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + if args.whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device=device) + ) + else: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") @@ -86,7 +122,11 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_kespeech_dev_test() + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + compute_fbank_kespeech_dev_test(args) if __name__ == "__main__": diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py index 8bfbc7b50..c398411f6 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py @@ -28,10 +28,14 @@ from lhotse import ( KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, set_audio_duration_mismatch_tolerance, set_caching_enabled, ) +from icefall.utils import str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -88,6 +92,20 @@ def get_parser(): default=-1, help="Stop processing pieces until this number (exclusive).", ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. 
Default: False.", + ) return parser @@ -111,14 +129,19 @@ def compute_fbank_kespeech_splits(args): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + if args.whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) + else: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance set_caching_enabled(False) for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) - logging.info(f"Processing {idx}/{num_splits}") + idx = f"{i}".zfill(num_digits) + logging.info(f"Processing {i+1}/{num_splits}") cuts_path = output_dir / f"kespeech-asr_cuts_{subset}.{idx}.jsonl.gz" if cuts_path.is_file(): diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py index 5649d3815..192bffa9f 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py @@ -30,10 +30,17 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -43,10 +50,33 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False): +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + return parser + + +def compute_fbank_magicdata( + num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/magicdata") output_dir = Path("data/fbank") - num_jobs = min(30, os.cpu_count()) + num_jobs = min(8, os.cpu_count()) dataset_parts = ("train", "test", "dev") prefix = "magicdata" @@ -66,7 +96,12 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False) dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if args.whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -107,7 +142,12 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) - + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. 
Default: False.", + ) return parser.parse_args() @@ -118,5 +158,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_magicdata( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + num_mel_bins=args.num_mel_bins, + speed_perturb=args.speed_perturb, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py index 303a16580..019b10d24 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py @@ -30,10 +30,17 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -43,7 +50,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False): +def compute_fbank_primewords( + num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/primewords") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -66,7 +75,12 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -108,6 +122,13 @@ def get_args(): help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + return parser.parse_args() @@ -118,5 +139,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_primewords( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + num_mel_bins=args.num_mel_bins, + speed_perturb=args.speed_perturb, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py index 730806954..f29ae5a46 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py @@ -30,10 +30,17 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
@@ -43,7 +50,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False): +def compute_fbank_stcmds( + num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/stcmds") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -66,7 +75,12 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False): dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -107,6 +121,12 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) return parser.parse_args() @@ -117,5 +137,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_stcmds( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + num_mel_bins=args.num_mel_bins, + speed_perturb=args.speed_perturb, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py index 58bb8002a..4ad78e0ba 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py @@ -30,10 +30,17 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -43,7 +50,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): +def compute_fbank_thchs30( + num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/thchs30") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -70,7 +79,12 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -113,6 +127,12 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. 
Default: False.", + ) return parser.parse_args() @@ -123,5 +143,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_thchs30( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + num_mel_bins=args.num_mel_bins, + speed_perturb=args.speed_perturb, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index c09b9c1de..fa515ed50 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -60,7 +60,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ ! -f data/fbank/.thchs30.done ]; then mkdir -p data/fbank - ./local/compute_fbank_thchs30.py + ./local/compute_fbank_thchs30.py --speed-perturb true touch data/fbank/.thchs30.done fi fi @@ -86,7 +86,7 @@ fi log "Dataset: AISHELL-2" if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Prepare AISHELL-2" - if [ -e ../../aishell/ASR/data/fbank/.aishell2.done ]; then + if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then cd data/fbank ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_train) . ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_dev) . @@ -95,30 +95,30 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_dev.jsonl.gz) . ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_test.jsonl.gz) . cd ../.. - else + else log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" exit 1 - fi + fi fi log "Dataset: AISHELL-4" if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Prepare AISHELL-4" - if [ -e ../../aishell/ASR/data/fbank/.aishell4.done ]; then + if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then cd data/fbank - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) . + ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_L) . + ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_M) . + ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_S) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) . cd ../.. - else + else log "Abort! Please run ../../aishell4/ASR/prepare.sh --stage 3 --stop-stage 3" exit 1 - fi + fi fi log "Dataset: ST-CMDS" @@ -137,7 +137,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ ! -f data/fbank/.stcmds.done ]; then mkdir -p data/fbank - ./local/compute_fbank_stcmds.py + ./local/compute_fbank_stcmds.py --speed-perturb true touch data/fbank/.stcmds.done fi fi @@ -151,15 +151,15 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then lhotse download primewords $dl_dir/primewords fi - if [ ! -f data/manifests/.stcmds.done ]; then + if [ ! 
-f data/manifests/.primewords.done ]; then mkdir -p data/manifests - lhotse prepare stcmds $dl_dir/primewords data/manifests/primewords + lhotse prepare primewords $dl_dir/primewords data/manifests/primewords touch data/manifests/.primewords.done fi if [ ! -f data/fbank/.primewords.done ]; then mkdir -p data/fbank - ./local/compute_fbank_primewords.py + ./local/compute_fbank_primewords.py --speed-perturb true touch data/fbank/.primewords.done fi fi @@ -180,7 +180,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then if [ ! -f data/fbank/.magicdata.done ]; then mkdir -p data/fbank - ./local/compute_fbank_magicdata.py + ./local/compute_fbank_magicdata.py --speed-perturb true touch data/fbank/.magicdata.done fi fi @@ -231,7 +231,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_1000) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_${num_splits}) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/*.lca) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/) ./wenetspeech cd ../.. @@ -261,7 +261,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then if [ ! -f data/manifests/.kespeech.done ]; then mkdir -p data/manifests - lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech + lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech touch data/manifests/.kespeech.done fi @@ -272,29 +272,29 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then python3 ./local/preprocess_kespeech.py touch data/fbank/.kespeech_preprocess_complete - fi - - if [ -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then + fi + + if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then log "Spliting KeSpeech train_phase1" lhotse split ${num_splits} \ data/fbank/kespeech/kespeech-asr_cuts_train_phase1_raw.jsonl.gz \ data/fbank/kespeech/train_phase1_split_${num_splits} touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done fi - - if [ -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then + + if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then log "Spliting KeSpeech train_phase2" lhotse split ${num_splits} \ data/fbank/kespeech/kespeech-asr_cuts_train_phase2_raw.jsonl.gz \ data/fbank/kespeech/train_phase2_split_${num_splits} touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done fi - + log "Compute KeSpeech fbank for train_phase1" - ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1 + ./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase1 log "Compute KeSpeech fbank for train_phase2" - ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase2 + ./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase2 log "Compute KeSpeech fbank for test/dev" ./local/compute_fbank_kespeech_dev_test.py @@ -303,13 +303,126 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then fi fi +whisper_mel_bins=80 +if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then + log "Stage 120: Prepare KeSpeech for whisper" + if [ ! -d $dl_dir/KeSpeech ]; then + log "Abort! 
Please download KeSpeech first." + log "KeSpeech download link: https://github.com/KeSpeech/KeSpeech" + exit 1 + fi + + if [ ! -f data/manifests/.kespeech.done ]; then + mkdir -p data/manifests + lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech + touch data/manifests/.kespeech.done + fi + + if [ ! -f data/fbank/.kespeech.done ]; then + mkdir -p data/fbank + + log "Preprocess KeSpeech manifest" + if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then + python3 ./local/preprocess_kespeech.py --speed-perturb true + touch data/fbank/.kespeech_preprocess_complete + fi + + if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then + log "Spliting KeSpeech train_phase1" + lhotse split ${num_splits} \ + data/fbank/kespeech/kespeech-asr_cuts_train_phase1_raw.jsonl.gz \ + data/fbank/kespeech/train_phase1_split_${num_splits} + touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done + fi + + if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then + log "Spliting KeSpeech train_phase2" + lhotse split ${num_splits} \ + data/fbank/kespeech/kespeech-asr_cuts_train_phase2_raw.jsonl.gz \ + data/fbank/kespeech/train_phase2_split_${num_splits} + touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done + fi + + log "Compute KeSpeech fbank for train_phase1" + ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + + log "Compute KeSpeech fbank for train_phase2" + ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase2 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + + log "Compute KeSpeech fbank for test/dev" + # ./local/compute_fbank_kespeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + + if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then + pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz") + lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz + fi + if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then + pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz") + lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz + fi + touch data/fbank/.kespeech.done + fi +fi + +if [ $stage -le 121 ] && [ $stop_stage -ge 121 ]; then + log "Stage 121: Prepare MagicData, Primewords, ST-CMDS, THCHS-30 for whisper" + + if [ ! -f data/manifests/.magicdata.done ]; then + mkdir -p data/manifests + lhotse prepare magicdata $dl_dir/magicdata data/manifests/magicdata + touch data/manifests/.magicdata.done + fi + + if [ ! -f data/manifests/.primewords.done ]; then + mkdir -p data/manifests + lhotse prepare primewords $dl_dir/primewords data/manifests/primewords + touch data/manifests/.primewords.done + fi + if [ ! -f data/manifests/.stcmds.done ]; then + mkdir -p data/manifests + lhotse prepare stcmds $dl_dir/stcmds data/manifests/stcmds + touch data/manifests/.stcmds.done + fi + + if [ ! -f data/manifests/.thchs30.done ]; then + mkdir -p data/manifests + lhotse prepare thchs-30 $dl_dir/thchs30 data/manifests/thchs30 + touch data/manifests/.thchs30.done + fi + + if [ ! 
-f data/fbank/.thchs30.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_thchs30.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + touch data/fbank/.thchs30.done + fi + + if [ ! -f data/fbank/.stcmds.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_stcmds.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + touch data/fbank/.stcmds.done + fi + if [ ! -f data/fbank/.magicdata.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_magicdata.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + touch data/fbank/.magicdata.done + fi + + if [ ! -f data/fbank/.primewords.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_primewords.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + touch data/fbank/.primewords.done + fi + +fi + + if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then log "Stage 13: BPE model training (note that we use transcripts of wenetspeech only for BPE training)" ./local/prepare_for_bpe_model.py --lang-dir ./data/lang_char --text ./data/lang_char/text for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} - + mkdir -p $lang_dir if [ ! -f $lang_dir/bpe.model ]; then ./local/train_bpe_model.py \ @@ -329,7 +442,7 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then --lexicon $lang_dir/lexicon.txt \ --bpe-model $lang_dir/bpe.model fi - + if [ ! -f $lang_dir/L.fst ]; then log "Converting L.pt to L.fst" ./shared/convert-k2-to-openfst.py \ @@ -350,7 +463,7 @@ fi if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then log "Stage 14: Prepare G (note that we use ngram lm of wenetspeech only for G preparation)" - + if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then cd data ln -s ../../../../wenetspeech/ASR/data/lm . @@ -369,5 +482,3 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then python ./local/compile_lg.py --lang-dir $lang_dir done fi - - diff --git a/egs/multi_zh-hans/ASR/whisper/asr_datamodule.py b/egs/multi_zh-hans/ASR/whisper/asr_datamodule.py new file mode 120000 index 000000000..3c8b7f2d4 --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/asr_datamodule.py @@ -0,0 +1 @@ +../zipformer/asr_datamodule.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/decode.py b/egs/multi_zh-hans/ASR/whisper/decode.py new file mode 100644 index 000000000..aabb80eaf --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/decode.py @@ -0,0 +1,519 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Wei Kang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +# Command for decoding using fine-tuned models: +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --beam-size 10 --max-duration 50 + +# Command for decoding using pretrained models (before fine-tuning): + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 10 --max-duration 50 + +""" + +import argparse +import logging +import re +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +import whisper +from asr_datamodule import AsrDataModule +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from tn.chinese.normalizer import Normalizer +from whisper.normalizers import BasicTextNormalizer +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward +from zhconv import convert + +from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def average_checkpoints( + filenames: List[Path], device: torch.device = torch.device("cpu") +) -> dict: + """Average a list of checkpoints. + The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. + + Args: + filenames: + Filenames of the checkpoints to be averaged. We assume all + checkpoints are saved by :func:`save_checkpoint`. + device: + Move checkpoints to this device before averaging. + Returns: + Return a dict (i.e., state_dict) which is the average of all + model state dicts contained in the checkpoints. + """ + n = len(filenames) + + if "model" in torch.load(filenames[0], map_location=device): + avg = torch.load(filenames[0], map_location=device)["model"] + else: + avg = torch.load(filenames[0], map_location=device) + + # Identify shared parameters. Two parameters are said to be shared + # if they have the same data_ptr + uniqued: Dict[int, str] = dict() + + for k, v in avg.items(): + v_data_ptr = v.data_ptr() + if v_data_ptr in uniqued: + continue + uniqued[v_data_ptr] = k + + uniqued_names = list(uniqued.values()) + + for i in range(1, n): + if "model" in torch.load(filenames[i], map_location=device): + state_dict = torch.load(filenames[i], map_location=device)["model"] + else: + state_dict = torch.load(filenames[i], map_location=device) + for k in uniqued_names: + avg[k] += state_dict[k] + + for k in uniqued_names: + if avg[k].is_floating_point(): + avg[k] /= n + else: + avg[k] //= n + + return avg + + +def remove_punctuation(text: str or List[str]): + """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py + + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings without any punctuation. 
+ """ + punctuation = "!,.;:?、!,。;:?《》 " + if isinstance(text, str): + text = re.sub(r"[{}]+".format(punctuation), "", text).strip() + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = re.sub(r"[{}]+".format(punctuation), "", t).strip() + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type {type(text)}") + + +def to_simple(text: str or List[str]): + """Convert traditional Chinese to simplified Chinese. + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings converted to simplified Chinese. + """ + if isinstance(text, str): + text = convert(text, "zh-cn") + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = convert(t, "zh-cn") + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type{type(text)}") + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=-1, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="beam-search", + help="""Decoding method. + Supported values are: + - beam-search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=1, + help="beam size for beam search decoding", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. + """, + ) + + parser.add_argument( + "--remove-whisper-encoder-input-length-restriction", + type=str2bool, + default=True, + help="replace whisper encoder forward method to remove input length restriction", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: "beam-search" + - value: A list of lists. Each sublist is a list of token IDs. + Args: + params: + It is returned by :func:`get_params`. + model: + The neural model. + batch: + It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. + Returns: + Return a dict, whose key may be "beam-search". 
+ """ + dtype = torch.float16 + device = torch.device("cuda") + + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device, dtype=dtype).transpose(1, 2) + if not params.remove_whisper_encoder_input_length_restriction: + T = 3000 + if feature.shape[2] < T: + feature = torch.cat( + [ + feature, + torch.zeros( + feature.shape[0], feature.shape[1], T - feature.shape[2] + ).to(device, dtype=dtype), + ], + 2, + ) + + supervisions = batch["supervisions"] + feature_len = supervisions["num_frames"] + feature_len = feature_len.to(device, dtype=dtype) + results = model.decode(feature, params.decoding_options) + hyps = [result.text for result in results] + + hyps = remove_punctuation(hyps) + hyps = to_simple(hyps) + hyps = [params.normalizer.normalize(hyp) for hyp in hyps] + print(hyps) + return {"beam-search": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + The dataloader. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "beam-search". + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + batch=batch, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + # we compute CER for aishell dataset. 
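+        # For example, a hypothetical entry ("cut-1", ["今天", "天气"], "今天天气")
+        # becomes ("cut-1", ["今", "天", "天", "气"], ["今", "天", "天", "气"]),
+        # so write_error_stats() below effectively reports CER.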
+ results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + setup_logger( + f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + ) + + options = whisper.DecodingOptions( + task="transcribe", + language="zh", + without_timestamps=True, + beam_size=params.beam_size, + ) + params.decoding_options = options + params.cleaner = BasicTextNormalizer() + params.normalizer = Normalizer() + + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + + logging.info(f"device: {device}") + + if params.remove_whisper_encoder_input_length_restriction: + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + if params.epoch > 0: + if params.avg > 1: + start = params.epoch - params.avg + assert start >= 1, start + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + # deepspeed converted checkpoint only contains model state_dict + filenames = [ + f"{params.exp_dir}/epoch-{epoch}.pt" + for epoch in range(start, params.epoch + 1) + ] + model.load_state_dict(average_checkpoints(filenames)) + else: + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + # save checkpoints + filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save(model.state_dict(), filename) + else: + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + model.load_state_dict(checkpoint, strict=True) + else: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) + + def remove_long_utt(c: Cut): + # Keep only utterances with duration in 30 seconds + # + if c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dls = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/whisper/ds_config_zero1.json b/egs/multi_zh-hans/ASR/whisper/ds_config_zero1.json new file mode 120000 index 000000000..af7162d6c --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/ds_config_zero1.json @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/ds_config_zero1.json \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/label_smoothing.py b/egs/multi_zh-hans/ASR/whisper/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/multi_dataset.py b/egs/multi_zh-hans/ASR/whisper/multi_dataset.py new file mode 100644 index 000000000..b562e626b --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/multi_dataset.py @@ -0,0 +1,296 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
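+
+# A minimal usage sketch (the fbank directory below is an assumption; see the
+# class docstring for the manifests it expects):
+#
+#     multi_dataset = MultiDataset("data/fbank")
+#     train_cuts = multi_dataset.train_cuts()  # weighted mux of all corpora
+#     dev_cuts = multi_dataset.dev_cuts()      # currently the WeNetSpeech DEV set
+#     test_cuts = multi_dataset.test_cuts()    # dict: test-set name -> CutSet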
+ + +import glob +import logging +import re +from pathlib import Path +from typing import Dict, List + +import lhotse +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, fbank_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - aishell_cuts_train.jsonl.gz + - aishell2_cuts_train.jsonl.gz + - aishell4_cuts_train_L.jsonl.gz + - aishell4_cuts_train_M.jsonl.gz + - aishell4_cuts_train_S.jsonl.gz + - alimeeting-far_cuts_train.jsonl.gz + - magicdata_cuts_train.jsonl.gz + - primewords_cuts_train.jsonl.gz + - stcmds_cuts_train.jsonl.gz + - thchs_30_cuts_train.jsonl.gz + - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz + - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz + - wenetspeech/cuts_L.jsonl.gz + """ + self.fbank_dir = Path(fbank_dir) + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + # THCHS-30 + logging.info("Loading THCHS-30 in lazy mode") + thchs_30_cuts = load_manifest_lazy( + self.fbank_dir / "thchs_30_cuts_train.jsonl.gz" + ) + + # AISHELL-1 + logging.info("Loading Aishell-1 in lazy mode") + aishell_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_train.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 in lazy mode") + aishell_2_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_train.jsonl.gz" + ) + + # AISHELL-4 + logging.info("Loading Aishell-4 in lazy mode") + aishell_4_L_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz" + ) + aishell_4_M_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz" + ) + aishell_4_S_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz" + ) + + # ST-CMDS + logging.info("Loading ST-CMDS in lazy mode") + stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz") + + # Primewords + logging.info("Loading Primewords in lazy mode") + primewords_cuts = load_manifest_lazy( + self.fbank_dir / "primewords_cuts_train.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData in lazy mode") + magicdata_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_train.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting in lazy mode") + alimeeting_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech in lazy mode") + wenetspeech_L_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech in lazy mode") + kespeech_1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz" + ) + kespeech_2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz" + ) + + return CutSet.mux( + thchs_30_cuts, + aishell_cuts, + aishell_2_cuts, + aishell_4_L_cuts, + aishell_4_M_cuts, + aishell_4_S_cuts, + stcmds_cuts, + primewords_cuts, + magicdata_cuts, + alimeeting_cuts, + wenetspeech_L_cuts, + kespeech_1_cuts, + kespeech_2_cuts, + weights=[ + len(thchs_30_cuts), + len(aishell_cuts), + len(aishell_2_cuts), + len(aishell_4_L_cuts), + len(aishell_4_M_cuts), + len(aishell_4_S_cuts), + len(stcmds_cuts), + len(primewords_cuts), + len(magicdata_cuts), + len(alimeeting_cuts), + len(wenetspeech_L_cuts), + len(kespeech_1_cuts), + len(kespeech_2_cuts), + ], + ) + + def dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + + # AISHELL + 
logging.info("Loading Aishell DEV set in lazy mode") + aishell_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_dev.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 DEV set in lazy mode") + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting DEV set in lazy mode") + alimeeting_dev_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData DEV set in lazy mode") + magicdata_dev_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech DEV set in lazy mode") + kespeech_dev_phase1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" + ) + kespeech_dev_phase2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech DEV set in lazy mode") + wenetspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" + ) + + return wenetspeech_dev_cuts + # return [ + # aishell_dev_cuts, + # aishell2_dev_cuts, + # alimeeting_dev_cuts, + # magicdata_dev_cuts, + # kespeech_dev_phase1_cuts, + # kespeech_dev_phase2_cuts, + # wenetspeech_dev_cuts, + # ] + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + # AISHELL + logging.info("Loading Aishell set in lazy mode") + aishell_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_test.jsonl.gz" + ) + aishell_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_dev.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 set in lazy mode") + aishell2_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_test.jsonl.gz" + ) + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # AISHELL-4 + logging.info("Loading Aishell-4 TEST set in lazy mode") + aishell4_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_test.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting set in lazy mode") + alimeeting_test_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz" + ) + alimeeting_eval_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData set in lazy mode") + magicdata_test_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_test.jsonl.gz" + ) + magicdata_dev_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech set in lazy mode") + kespeech_test_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz" + ) + kespeech_dev_phase1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" + ) + kespeech_dev_phase2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech set in lazy mode") + wenetspeech_test_meeting_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz" + ) + wenetspeech_test_net_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz" + ) + wenetspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / 
"cuts_DEV.jsonl.gz" + ) + + return { + "aishell-2_test": aishell2_test_cuts, + "aishell-4": aishell4_test_cuts, + "magicdata_test": magicdata_test_cuts, + "kespeech-asr_test": kespeech_test_cuts, + } + + # return { + # "alimeeting_test": alimeeting_test_cuts, + # "alimeeting_eval": alimeeting_eval_cuts, + # "aishell_test": aishell_test_cuts, + # "aishell_dev": aishell_dev_cuts, + # "aishell-2_test": aishell2_test_cuts, + # "aishell-2_dev": aishell2_dev_cuts, + # "aishell-4": aishell4_test_cuts, + # "magicdata_test": magicdata_test_cuts, + # "magicdata_dev": magicdata_dev_cuts, + # "kespeech-asr_test": kespeech_test_cuts, + # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts, + # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts, + # "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, + # "wenetspeech-net_test": wenetspeech_test_net_cuts, + # "wenetspeech_dev": wenetspeech_dev_cuts, + # } diff --git a/egs/multi_zh-hans/ASR/whisper/optim.py b/egs/multi_zh-hans/ASR/whisper/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/requirements.txt b/egs/multi_zh-hans/ASR/whisper/requirements.txt new file mode 120000 index 000000000..744bf8bb6 --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/requirements.txt @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/requirements.txt \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py new file mode 100644 index 000000000..11a22eec1 --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -0,0 +1,983 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +#fine-tuning with deepspeed zero stage 1 +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --deepspeed \ + --deepspeed_config ./whisper/ds_config_zero1.json + +# fine-tuning with ddp +torchrun --nproc_per_node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_medium \ + --base-lr 1e-5 \ + --model-name medium +""" + +import argparse +import copy +import logging +import os +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import deepspeed +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import whisper +from asr_datamodule import AsrDataModule +from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict +from label_smoothing import LabelSmoothingLoss +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from multi_dataset import MultiDataset +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.functional import pad as pad_tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import update_averaged_model +from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. + """, + ) + + parser.add_argument( + "--pretrained-model-path", + type=str, + default=None, + help="""The path to the pretrained model if it is not None. Training will + start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=1e-5, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser = deepspeed.add_config_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - frame_shift_ms: The frame shift in milliseconds. + - allowed_excess_duration_ratio: The allowed excess duration ratio. + - best_train_loss: The best training loss so far. + - best_valid_loss: The best validation loss so far. + - best_train_epoch: The epoch where the best training loss is achieved. + - best_valid_epoch: The epoch where the best validation loss is achieved. + - batch_idx_train: The batch index of the current batch. + - log_interval: Log training stats every `log_interval` batches. + - reset_interval: Reset the stats every `reset_interval` batches. + - valid_interval: Run validation every `valid_interval` batches. + - env_info: The environment information. 
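+
+    A short illustration of how these defaults interact with the command line
+    (names are from this file):
+
+        params = get_params()
+        params.update(vars(args))     # CLI options are merged into params
+        params.valid_interval         # -> 10000 (not a CLI option)
+        params.base_lr                # -> whatever --base-lr was set to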
+ """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "subsampling_factor": 2, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
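+      scheduler:
+        The learning rate scheduler used in the training.
+      rank:
+        The rank of the node in DDP training; only rank 0 writes the
+        checkpoint files.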
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute the loss for the given batch. + Args: + params: + It is returned by :func:`get_params`. + tokenizer: + The tokenizer used to encode the text. + model: + The model for training. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + Whether it is training. + Returns: + Return a tuple of two elements. The first element is the loss tensor. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor: + padding_size = max(tensor.shape[0] for tensor in tensors) + dims = len(tensors[0].shape) + padded_tensors = [] + for tensor in tensors: + padding = [0] * 2 * dims + padding[-1] = padding_size - tensor.shape[0] + padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) + return torch.stack([tensor for tensor in padded_tensors], dim=0) + + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + + assert feature.ndim == 3 + feature = feature.to(device) + feature = feature.transpose(1, 2) # (N, C, T) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + # remove spaces in texts + texts = [text.replace(" ", "") for text in texts] + + text_tokens_list = [ + list(tokenizer.sot_sequence_including_notimestamps) + + tokenizer.encode(text) + + [tokenizer.eot] + for text in texts + ] + # convert it to torch tensor + text_tokens_list = [ + torch.LongTensor(text_tokens) for text_tokens in text_tokens_list + ] + + # 50256 is the index of for all whisper models + prev_outputs_tokens = _batch_tensors( + [tokens[:-1] for tokens in text_tokens_list], pad_value=50256 + ) + target_tokens = _batch_tensors( + [tokens[1:] for tokens in text_tokens_list], pad_value=50256 + ) + target_lengths = torch.LongTensor( + [tokens.shape[0] - 1 for 
tokens in text_tokens_list] + ) + + decoder_criterion = LabelSmoothingLoss( + ignore_index=50256, label_smoothing=0.1, reduction="sum" + ) + + # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|> + ignore_prefix_size = 3 + with torch.set_grad_enabled(is_training): + encoder_out = model.encoder(feature) + text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) + text_logits = text_logits[:, ignore_prefix_size:, :] + target_tokens = target_tokens[:, ignore_prefix_size:] + loss = decoder_criterion(text_logits, target_tokens.to(device)) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
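+      tokenizer:
+        The Whisper tokenizer used to encode the transcripts when computing
+        the training and validation losses.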
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + tokenizer=tokenizer, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + if params.deepspeed: + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", + client_state={}, + ) + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", + tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", + ) + os.system( + f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" + ) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + if params.deepspeed: + # deepspeed's backward() is different from torch's backward() + # in that it does not accept a loss tensor as input. + # It computes the loss internally. + model.backward(loss) + model.step() + else: + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + and not params.deepspeed + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
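+            # Concretely: the scale is inspected every 100 batches; below 1.0
+            # it is doubled at every inspection, and between 1.0 and 8.0 it is
+            # doubled only when batch_idx is also a multiple of 400.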
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + if batch_idx % params.log_interval == 0: + try: + cur_lr = scheduler.get_last_lr()[0] + except: # noqa + cur_lr = 0.0 + cur_grad_scale = ( + scaler._scale.item() + if (params.use_fp16 and not params.deepspeed) + else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if (params.use_fp16 and not params.deepspeed) + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info(params) + + logging.info("About to create model") + + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + del model.alignment_heads + + if params.pretrained_model_path: + checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") + if "model" not in checkpoint: + model.load_state_dict(checkpoint, strict=True) + else: + load_checkpoint(params.pretrained_model_path, model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + tokenizer = whisper.tokenizer.get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language="zh", + task="transcribe", + ) + + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + else: + device = torch.device("cpu") + logging.info(f"Device: {device}") + model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + 
logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if world_size > 1: + if params.deepspeed: + logging.info("Using DeepSpeed") + model, optimizer, _, scheduler = deepspeed.initialize( + args=params, model=model, model_parameters=model.parameters() + ) + else: + logging.info("Using DDP") + setup_dist(use_ddp_launch=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = multi_dataset.train_cuts() + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = data_module.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = multi_dataset.dev_cuts() + valid_dl = data_module.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + logging.info(f"start training from epoch {params.start_epoch}") + for epoch in range(params.start_epoch, params.num_epochs + 1): + if not params.deepspeed: + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + tokenizer=tokenizer, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.deepspeed: + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}", + client_state={}, + ) + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", + tag=f"epoch-{params.cur_epoch}", + ) + os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") + else: + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + 
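+    # At this point every finished epoch has left an exp-dir/epoch-N.pt file:
+    # on rank 0 either converted from the DeepSpeed ZeRO checkpoint (with the
+    # raw checkpoint directory removed afterwards) or written by
+    # save_checkpoint() above.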
+ logging.info("Done!") + + if world_size > 1 and not params.deepspeed: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = get_world_size() + rank = get_rank() + + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + run(rank=rank, world_size=world_size, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/multi_zh-hans/ASR/whisper/whisper_encoder_forward_monkey_patch.py new file mode 120000 index 000000000..2a7808921 --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/whisper_encoder_forward_monkey_patch.py @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py \ No newline at end of file diff --git a/egs/speechio/ASR/README.md b/egs/speechio/ASR/README.md new file mode 100644 index 000000000..2675efd9b --- /dev/null +++ b/egs/speechio/ASR/README.md @@ -0,0 +1,15 @@ + +# Introduction + +This recipe includes some different pretrained ASR models' decoding results with [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets. + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Pretrained Models + +The following table lists the pretrained models. + +| | Huggingface | Comment | +|---------------------------------------|--------------------|-----------------------------| +| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 | Using [multi-hans-zh recipe](../../multi_zh-hans/ASR/zipformer/) training | | +| `whisper` | yuekai/icefall_asr_wenetspeech_whisper | Using [wenetspeech recipe](../../wenetspeech/ASR/whisper/) training | diff --git a/egs/speechio/ASR/RESULTS.md b/egs/speechio/ASR/RESULTS.md new file mode 100644 index 000000000..07649e383 --- /dev/null +++ b/egs/speechio/ASR/RESULTS.md @@ -0,0 +1,92 @@ +## Results + +### SpeechIO Test Set Decoding Results + +##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whipser-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper). 
+ +| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion | +|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------| +| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 | +| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 | +| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 | +| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 | +| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 | +| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 | +| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 | +| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 | +| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 | +| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 | +| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 | +| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 | +| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 | +| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 | +| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 | +| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 | +| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 | +| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 | +| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 | +| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 | +| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 | +| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 | +| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 | +| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 | +| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 | +| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 | +| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 | +| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 | + + + + +Command for decoding using fine-tuned whisper: +```bash +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2_wenetspeech \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --start-index 0 --end-index 26 \ + --remove-whisper-encoder-input-length-restriction True \ + --manifest-dir data/fbank \ + --beam-size 1 --max-duration 50 +``` +Command for decoding using pretrained zipformer: +```bash +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 +cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "data/lang_bpe_2000/*" +ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt +ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data +wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt +mv words.txt ./data/lang_bpe_2000/ + +./zipformer/decode.py \ + --epoch 999 \ + --avg 1 \ + --blank-penalty 2.0 \ + --use-averaged-model false \ + --exp-dir ./zipformer/exp_pretrain \ + --max-duration 600 \ + --start-index 0 --end-index 26 \ + 
--manifest-dir data/fbank_kaldi \ + --decoding-method greedy_search +``` +Command for fusion the above decoding results from whisper and zipformer: +```bash +python local/whisper_zipformer_fusion.py \ + --whisper-log-dir ./whisper/exp_large_v2_wenetspeech \ + --zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \ + --output-log-dir ./results_fusion + +``` + +See why the fusion helps [here](./local/whisper_zipformer_fusion.py). + +SpeechIO fbank features, decoding scripts, logs, and decoding results +are available at + diff --git a/egs/speechio/ASR/local/compute_fbank_speechio.py b/egs/speechio/ASR/local/compute_fbank_speechio.py new file mode 100644 index 000000000..5b3489a9f --- /dev/null +++ b/egs/speechio/ASR/local/compute_fbank_speechio.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the ST-CMDS dataset. +It looks for manifests in the directory data/manifests/stcmds. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source. + + +def compute_fbank_speechio( + num_mel_bins: int = 80, + speed_perturb: bool = False, + fbank_dir: str = "data/fbank", + whisper_fbank: bool = False, +): + src_dir = Path("data/manifests") + output_dir = Path(fbank_dir) + num_jobs = min(8, os.cpu_count()) + + dataset_parts = [] + for i in range(SPEECHIO_TESTSET_INDEX + 1): + idx = f"{i}".zfill(2) + dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") + + prefix = "speechio" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. 
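+        # For each SPEECHIO_ASR_ZH000xx partition: skip it if its cut manifest
+        # already exists, otherwise extract (Whisper) fbank features with
+        # num_jobs local workers (80 partitions when a cluster executor is
+        # available) and store them with LilcomChunkyWriter.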
+ for partition, m in manifests.items(): + if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + parser.add_argument( + "--fbank-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + compute_fbank_speechio( + num_mel_bins=args.num_mel_bins, + fbank_dir=args.fbank_dir, + whisper_fbank=args.whisper_fbank, + ) diff --git a/egs/speechio/ASR/local/display_manifest_statistics.py b/egs/speechio/ASR/local/display_manifest_statistics.py new file mode 100644 index 000000000..0c803bfcd --- /dev/null +++ b/egs/speechio/ASR/local/display_manifest_statistics.py @@ -0,0 +1,1162 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer_stateless/train.py +for usage. +""" + +SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source. 
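+# Partition names are built the same way as in compute_fbank_speechio.py,
+# e.g. index 7 -> "SPEECHIO_ASR_ZH00007" (two-digit zero padding appended to
+# the "SPEECHIO_ASR_ZH000" prefix).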
+ +from lhotse import load_manifest_lazy + + +def main(): + dataset_parts = [] + for i in range(SPEECHIO_TESTSET_INDEX + 1): + idx = f"{i}".zfill(2) + dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") + + prefix = "speechio" + suffix = "jsonl.gz" + + for partition in dataset_parts: + path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}" + cuts = load_manifest_lazy(path) + print( + f"===================Duration statistics of {partition}===================" + ) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +===================Duration statistics of SPEECHIO_ASR_ZH00000=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 879 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:36:09 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.6 │ +├───────────────────────────┼──────────┤ +│ std │ 2.0 │ +├───────────────────────────┼──────────┤ +│ min │ 1.7 │ +├───────────────────────────┼──────────┤ +│ 25% │ 5.0 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.5 │ +├───────────────────────────┼──────────┤ +│ 75% │ 8.1 │ +├───────────────────────────┼──────────┤ +│ 99% │ 11.2 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 11.6 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 12.2 │ +├───────────────────────────┼──────────┤ +│ max │ 12.5 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 879 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 879 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 879 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:36:09 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:36:09 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00001=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 5069 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 08:43:04 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.2 │ +├───────────────────────────┼──────────┤ +│ std │ 2.1 │ +├───────────────────────────┼──────────┤ +│ min │ 0.6 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.6 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.2 │ +├───────────────────────────┼──────────┤ +│ 75% │ 7.9 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.7 │ +├───────────────────────────┼──────────┤ +│ max │ 12.5 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 5069 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 5069 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 5069 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 08:43:04 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 08:43:04 │ 100.00% of recording │ 
+├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00002=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 2993 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:45:09 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.3 │ +├───────────────────────────┼──────────┤ +│ std │ 1.5 │ +├───────────────────────────┼──────────┤ +│ min │ 0.4 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.2 │ +├───────────────────────────┼──────────┤ +│ 50% │ 3.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 4.3 │ +├───────────────────────────┼──────────┤ +│ 99% │ 7.3 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 7.8 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ max │ 11.8 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 2993 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 2993 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 2993 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:45:09 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:45:09 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00003=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1683 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:23:28 │ +├───────────────────────────┼──────────┤ +│ mean │ 5.1 │ +├───────────────────────────┼──────────┤ +│ std │ 1.4 │ +├───────────────────────────┼──────────┤ +│ min │ 2.4 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.0 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 6.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.4 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.8 │ +├───────────────────────────┼──────────┤ +│ max │ 14.2 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1683 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1683 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1683 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:23:28 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:23:28 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00004=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1311 │ 
+├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:49:16 │ +├───────────────────────────┼──────────┤ +│ mean │ 7.7 │ +├───────────────────────────┼──────────┤ +│ std │ 2.8 │ +├───────────────────────────┼──────────┤ +│ min │ 0.9 │ +├───────────────────────────┼──────────┤ +│ 25% │ 5.8 │ +├───────────────────────────┼──────────┤ +│ 50% │ 8.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 12.9 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 13.5 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 13.8 │ +├───────────────────────────┼──────────┤ +│ max │ 14.4 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1311 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1311 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1311 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:49:16 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:49:16 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00005=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 3148 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 04:22:47 │ +├───────────────────────────┼──────────┤ +│ mean │ 5.0 │ +├───────────────────────────┼──────────┤ +│ std │ 1.4 │ +├───────────────────────────┼──────────┤ +│ min │ 2.0 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.9 │ +├───────────────────────────┼──────────┤ +│ 99% │ 8.8 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.3 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.3 │ +├───────────────────────────┼──────────┤ +│ max │ 11.1 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 3148 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 3148 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 3148 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 04:22:47 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 04:22:47 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00006=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1561 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:39:33 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.8 │ +├───────────────────────────┼──────────┤ +│ std │ 2.2 │ +├───────────────────────────┼──────────┤ +│ min │ 0.4 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.2 │ +├───────────────────────────┼──────────┤ +│ 50% │ 3.3 │ 
+├───────────────────────────┼──────────┤ +│ 75% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.4 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 11.3 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 15.3 │ +├───────────────────────────┼──────────┤ +│ max │ 23.8 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1561 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1561 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1561 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:39:33 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:39:33 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00007=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 770 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 00:58:57 │ +├───────────────────────────┼──────────┤ +│ mean │ 4.6 │ +├───────────────────────────┼──────────┤ +│ std │ 2.4 │ +├───────────────────────────┼──────────┤ +│ min │ 0.7 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.7 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.0 │ +├───────────────────────────┼──────────┤ +│ 75% │ 6.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 11.8 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 13.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 15.1 │ +├───────────────────────────┼──────────┤ +│ max │ 18.7 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 770 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 770 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 770 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 00:58:57 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 00:58:57 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00008=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 884 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:16:55 │ +├───────────────────────────┼──────────┤ +│ mean │ 5.2 │ +├───────────────────────────┼──────────┤ +│ std │ 2.3 │ +├───────────────────────────┼──────────┤ +│ min │ 1.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.5 │ +├───────────────────────────┼──────────┤ +│ 50% │ 5.0 │ +├───────────────────────────┼──────────┤ +│ 75% │ 6.4 │ +├───────────────────────────┼──────────┤ +│ 99% │ 11.3 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 12.7 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 16.2 │ +├───────────────────────────┼──────────┤ +│ max │ 18.5 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 884 │ 
+├───────────────────────────┼──────────┤ +│ Features available: │ 884 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 884 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:16:55 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:16:55 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00009=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 3466 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 04:38:13 │ +├───────────────────────────┼──────────┤ +│ mean │ 4.8 │ +├───────────────────────────┼──────────┤ +│ std │ 1.9 │ +├───────────────────────────┼──────────┤ +│ min │ 1.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.4 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.5 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.9 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.5 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 11.3 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 12.5 │ +├───────────────────────────┼──────────┤ +│ max │ 13.1 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 3466 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 3466 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 3466 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 04:38:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 04:38:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00010=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 2251 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 04:12:54 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.7 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 1.4 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.5 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.3 │ +├───────────────────────────┼──────────┤ +│ 75% │ 8.5 │ +├───────────────────────────┼──────────┤ +│ 99% │ 14.9 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 15.5 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 15.8 │ +├───────────────────────────┼──────────┤ +│ max │ 16.2 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 2251 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 2251 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 2251 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 04:12:54 │ 100.00% of recording │ 
+├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 04:12:54 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00011=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1053 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 03:27:12 │ +├───────────────────────────┼──────────┤ +│ mean │ 11.8 │ +├───────────────────────────┼──────────┤ +│ std │ 3.4 │ +├───────────────────────────┼──────────┤ +│ min │ 1.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 11.5 │ +├───────────────────────────┼──────────┤ +│ 50% │ 13.0 │ +├───────────────────────────┼──────────┤ +│ 75% │ 13.9 │ +├───────────────────────────┼──────────┤ +│ 99% │ 15.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 15.1 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 20.7 │ +├───────────────────────────┼──────────┤ +│ max │ 22.2 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1053 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1053 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1053 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 03:27:12 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 03:27:12 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00012=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1170 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 03:23:34 │ +├───────────────────────────┼──────────┤ +│ mean │ 10.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.5 │ +├───────────────────────────┼──────────┤ +│ min │ 0.8 │ +├───────────────────────────┼──────────┤ +│ 25% │ 8.0 │ +├───────────────────────────┼──────────┤ +│ 50% │ 11.5 │ +├───────────────────────────┼──────────┤ +│ 75% │ 13.2 │ +├───────────────────────────┼──────────┤ +│ 99% │ 15.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 15.1 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 15.7 │ +├───────────────────────────┼──────────┤ +│ max │ 20.3 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1170 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1170 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1170 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 03:23:34 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 03:23:34 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration 
statistics of SPEECHIO_ASR_ZH00013=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1321 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:46:41 │ +├───────────────────────────┼──────────┤ +│ mean │ 4.8 │ +├───────────────────────────┼──────────┤ +│ std │ 1.5 │ +├───────────────────────────┼──────────┤ +│ min │ 0.9 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.8 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.8 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 8.5 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 9.5 │ +├───────────────────────────┼──────────┤ +│ max │ 9.7 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1321 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1321 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1321 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:46:41 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:46:41 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00014=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 856 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:00:39 │ +├───────────────────────────┼──────────┤ +│ mean │ 4.3 │ +├───────────────────────────┼──────────┤ +│ std │ 1.8 │ +├───────────────────────────┼──────────┤ +│ min │ 0.8 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.5 │ +├───────────────────────────┼──────────┤ +│ 99% │ 8.5 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.2 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 11.1 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 856 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 856 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 856 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:00:39 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:00:39 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00015=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1168 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:08:52 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.6 │ +├───────────────────────────┼──────────┤ +│ std │ 2.0 │ +├───────────────────────────┼──────────┤ +│ min │ 
1.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 5.3 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.8 │ +├───────────────────────────┼──────────┤ +│ 75% │ 8.2 │ +├───────────────────────────┼──────────┤ +│ 99% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.1 │ +├───────────────────────────┼──────────┤ +│ max │ 15.5 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1168 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1168 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1168 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:08:52 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:08:52 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00016=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1201 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:00:46 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.0 │ +├───────────────────────────┼──────────┤ +│ std │ 2.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.9 │ +├───────────────────────────┼──────────┤ +│ 25% │ 1.6 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.3 │ +├───────────────────────────┼──────────┤ +│ 75% │ 3.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.5 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 9.7 │ +├───────────────────────────┼──────────┤ +│ max │ 9.9 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1201 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1201 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1201 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:00:46 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:00:46 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00017=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1271 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:47:57 │ +├───────────────────────────┼──────────┤ +│ mean │ 5.1 │ +├───────────────────────────┼──────────┤ +│ std │ 2.2 │ +├───────────────────────────┼──────────┤ +│ min │ 1.0 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.3 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 6.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 9.7 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ 
+├───────────────────────────┼──────────┤ +│ max │ 10.4 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1271 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1271 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1271 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:47:57 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:47:57 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00018=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 899 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 00:51:12 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.4 │ +├───────────────────────────┼──────────┤ +│ std │ 1.2 │ +├───────────────────────────┼──────────┤ +│ min │ 1.3 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.5 │ +├───────────────────────────┼──────────┤ +│ 50% │ 3.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 4.1 │ +├───────────────────────────┼──────────┤ +│ 99% │ 6.7 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 7.1 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 8.2 │ +├───────────────────────────┼──────────┤ +│ max │ 9.2 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 899 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 899 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 899 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 00:51:12 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 00:51:12 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00019=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 615 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 00:41:43 │ +├───────────────────────────┼──────────┤ +│ mean │ 4.1 │ +├───────────────────────────┼──────────┤ +│ std │ 1.5 │ +├───────────────────────────┼──────────┤ +│ min │ 1.3 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.8 │ +├───────────────────────────┼──────────┤ +│ 50% │ 3.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.2 │ +├───────────────────────────┼──────────┤ +│ 99% │ 7.9 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 8.1 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 8.6 │ +├───────────────────────────┼──────────┤ +│ max │ 8.8 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 615 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 615 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 615 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: 
+╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 00:41:43 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 00:41:43 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00020=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1590 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:10:54 │ +├───────────────────────────┼──────────┤ +│ mean │ 4.9 │ +├───────────────────────────┼──────────┤ +│ std │ 1.5 │ +├───────────────────────────┼──────────┤ +│ min │ 1.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.8 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 6.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 8.5 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 8.7 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 9.2 │ +├───────────────────────────┼──────────┤ +│ max │ 10.4 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1590 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1590 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1590 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:10:54 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:10:54 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00021=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1035 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:44:07 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼──────────┤ +│ std │ 1.8 │ +├───────────────────────────┼──────────┤ +│ min │ 1.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.7 │ +├───────────────────────────┼──────────┤ +│ 50% │ 5.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 7.3 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.4 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.6 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 11.0 │ +├───────────────────────────┼──────────┤ +│ max │ 11.1 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1035 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1035 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1035 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:44:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:44:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% 
of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00022=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1026 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:40:43 │ +├───────────────────────────┼──────────┤ +│ mean │ 5.9 │ +├───────────────────────────┼──────────┤ +│ std │ 2.2 │ +├───────────────────────────┼──────────┤ +│ min │ 0.9 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.4 │ +├───────────────────────────┼──────────┤ +│ 50% │ 5.8 │ +├───────────────────────────┼──────────┤ +│ 75% │ 7.1 │ +├───────────────────────────┼──────────┤ +│ 99% │ 12.1 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 12.7 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 13.9 │ +├───────────────────────────┼──────────┤ +│ max │ 14.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1026 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1026 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1026 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:40:43 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:40:43 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00023=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1528 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:06:51 │ +├───────────────────────────┼──────────┤ +│ mean │ 5.0 │ +├───────────────────────────┼──────────┤ +│ std │ 2.5 │ +├───────────────────────────┼──────────┤ +│ min │ 0.5 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.1 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.5 │ +├───────────────────────────┼──────────┤ +│ 75% │ 6.6 │ +├───────────────────────────┼──────────┤ +│ 99% │ 12.3 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 13.9 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 15.8 │ +├───────────────────────────┼──────────┤ +│ max │ 16.8 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1528 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1528 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1528 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:06:51 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:06:51 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00024=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1930 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:39:02 │ +├───────────────────────────┼──────────┤ 
+│ mean │ 4.9 │ +├───────────────────────────┼──────────┤ +│ std │ 2.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.9 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.4 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.7 │ +├───────────────────────────┼──────────┤ +│ 75% │ 6.2 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.3 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.9 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 12.0 │ +├───────────────────────────┼──────────┤ +│ max │ 12.6 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1930 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1930 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1930 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:39:02 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:39:02 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00025=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1164 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:24:42 │ +├───────────────────────────┼──────────┤ +│ mean │ 4.4 │ +├───────────────────────────┼──────────┤ +│ std │ 1.9 │ +├───────────────────────────┼──────────┤ +│ min │ 0.9 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 4.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.6 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.4 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.9 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 12.5 │ +├───────────────────────────┼──────────┤ +│ max │ 13.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1164 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1164 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1164 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:24:42 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:24:42 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +===================Duration statistics of SPEECHIO_ASR_ZH00026=================== +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1336 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:25:38 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.5 │ +├───────────────────────────┼──────────┤ +│ std │ 2.3 │ +├───────────────────────────┼──────────┤ +│ min │ 0.5 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.8 │ +├───────────────────────────┼──────────┤ +│ 75% │ 8.3 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.4 │ 
+├───────────────────────────┼──────────┤ +│ 99.5% │ 11.9 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 12.9 │ +├───────────────────────────┼──────────┤ +│ max │ 13.3 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1336 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 1336 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1336 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:25:38 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:25:38 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +""" diff --git a/egs/speechio/ASR/local/whisper_zipformer_fusion.py b/egs/speechio/ASR/local/whisper_zipformer_fusion.py new file mode 100644 index 000000000..04c5e75f0 --- /dev/null +++ b/egs/speechio/ASR/local/whisper_zipformer_fusion.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 Author: Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file uses whisper and zipformer decoding results to generate fusion decoding results. +Since whisper model is more likely to make deletion errors and zipformer model is more likely to make substitution and insertion errors, +we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors. + +Usage: + python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import kaldialign + +from icefall.utils import store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--whisper-log-dir", + type=str, + default="./recogs_whisper", + help="The directory to store the whisper logs: e.g. 
recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt", + ) + parser.add_argument( + "--zipformer-log-dir", + type=str, + default="./recogs_zipformer", + help="The directory to store the zipformer logs", + ) + parser.add_argument( + "--output-log-dir", + type=str, + default="./results_fusion", + help="The directory to store the fusion logs", + ) + return parser + + +def save_results( + res_dir: Path, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + + suffix = "epoch-999-avg-1" + + for key, results in results_dict.items(): + recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + print(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + print("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = res_dir / f"wer-summary-{test_set_name}-{key}-{suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + print(s) + + +def extract_hyp_ref_wavname(filename): + """ + 0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升'] + 0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升 + """ + hyps, refs, wav_name = [], [], [] + with open(filename, "r") as f: + for line in f: + if "ref" in line: + ref = line.split("ref=")[1].strip() + ref = ref[2:-2] + list_elements = ref.split("', '") + ref = "".join(list_elements) + refs.append(ref) + elif "hyp" in line: + hyp = line.split("hyp=")[1].strip() + hyps.append(hyp) + wav_name.append(line.split(":")[0]) + return hyps, refs, wav_name + + +def get_pair_filenames( + whisper_log_dir, + zipformer_log_dir, + whisper_suffix="beam-search-epoch-999-avg-1", + zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0", +): + results = [] + start_index, end_index = 0, 26 + dataset_parts = [] + for i in range(start_index, end_index + 1): + idx = f"{i}".zfill(2) + dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") + for partition in dataset_parts: + whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt" + zipformer_filename = ( + f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt" + ) + results.append((whisper_filename, zipformer_filename)) + return results + + +def fusion_hyps_trust_substituion_insertion( + hyps_whisper, hyps_zipformer, refs, ERR="*" +): + """ + alignment example: + [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')] + left is whisper, right is zipformer + for whisper substitution, use left + for whisper insertion, use left + for whisper deletion, use right + """ + hyps_fusion = [] + for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs): + ali = kaldialign.align(hyp_w, hyp_z, ERR) + hyp_f = "" + for a in ali: + if a[0] == ERR: + hyp_f += 
a[1] + else: + hyp_f += a[0] + hyps_fusion.append(hyp_f) + return hyps_fusion + + +def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"): + """ + alignment example: + [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')] + left is whisper, right is zipformer + for whisper substitution, use left + for whisper insertion, use right + for whisper deletion, use right + """ + hyps_fusion = [] + for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs): + ali = kaldialign.align(hyp_w, hyp_z, ERR) + hyp_f = "" + for a in ali: + if a[0] == ERR: + hyp_f += a[1] + elif a[1] == ERR: + pass + else: + hyp_f += a[0] + hyps_fusion.append(hyp_f) + return hyps_fusion + + +def main(): + parser = get_parser() + args = parser.parse_args() + # mkdir output_log_dir + Path(args.output_log_dir).mkdir(parents=True, exist_ok=True) + pair_logs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir) + for pair in pair_logs: + hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0]) + hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1]) + + hyps_fusion = fusion_hyps_trust_substituion_insertion( + hyps_whisper, hyps_zipformer, refs + ) + + partition_name = pair[0].split("/")[-1].split("-")[1] + save_results( + Path(args.output_log_dir), + partition_name, + {"fusion": list(zip(wav_name, refs, hyps_fusion))}, + ) + + print(f"Processed {partition_name}") + + +if __name__ == "__main__": + main() diff --git a/egs/speechio/ASR/prepare.sh b/egs/speechio/ASR/prepare.sh new file mode 100644 index 000000000..048a66d8f --- /dev/null +++ b/egs/speechio/ASR/prepare.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash + +set -eou pipefail + +stage=3 +stop_stage=3 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/SPEECHIO_ASR_ZH00000 +# This directory contains the following files downloaded from +# https://github.com/SpeechColab/Leaderboard +# +# - metadata.tsv +# - wav +# - wav.scp +# - trans.txt +# + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare speechio manifest" + # We assume that you have downloaded the speechio dataset + # to $dl_dir + mkdir -p data/manifests + if [ ! -e data/manifests/.speechio.done ]; then + lhotse prepare speechio $dl_dir data/manifests + touch data/manifests/.speechio.done + fi +fi + +whisper_mel_bins=80 +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute whisper fbank for speechio" + if [ ! -f data/fbank/.speechio.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_speechio.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true + touch data/fbank/.speechio.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute kaldi fbank for speechio" + if [ ! 
-f data/fbank/.speechio.kaldi.done ]; then + fbank_dir=data/fbank_kaldi + mkdir -p $fbank_dir + ./local/compute_fbank_speechio.py --fbank-dir $fbank_dir + touch data/fbank/.speechio.kaldi.done + fi +fi diff --git a/egs/speechio/ASR/shared b/egs/speechio/ASR/shared new file mode 120000 index 000000000..9d8803a7d --- /dev/null +++ b/egs/speechio/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared// \ No newline at end of file diff --git a/egs/speechio/ASR/whisper/asr_datamodule.py b/egs/speechio/ASR/whisper/asr_datamodule.py new file mode 100644 index 000000000..7382fd3f5 --- /dev/null +++ b/egs/speechio/ASR/whisper/asr_datamodule.py @@ -0,0 +1,195 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrDataModule: + """ + DataModule for k2 ASR experiments. + There is no train and valid dataloader, for speechio dataset + but there can be multiple test dataloaders. + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=300.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + parser.add_argument( + "--start-index", + type=int, + default=0, + help="Decoding will start from dataset SPEECHIO_ASR_ZH000index", + ) + + parser.add_argument( + "--end-index", + type=int, + default=26, + help="Decoding will end with dataset SPEECHIO_ASR_ZH000index", + ) + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl diff --git a/egs/speechio/ASR/whisper/decode.py b/egs/speechio/ASR/whisper/decode.py new file mode 100644 index 000000000..001367791 --- /dev/null +++ b/egs/speechio/ASR/whisper/decode.py @@ -0,0 +1,520 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Wei Kang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +# Command for decoding using fine-tuned models: +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --beam-size 10 --max-duration 50 + +# Command for decoding using pretrained models (before fine-tuning): + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2_pretrained \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --start-index 14 --end-index 15 \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 1 --max-duration 50 + +""" + +import argparse +import logging +import re +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +import whisper +from asr_datamodule import AsrDataModule +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from tn.chinese.normalizer import Normalizer +from whisper.normalizers import BasicTextNormalizer +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward +from zhconv import convert + +from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def average_checkpoints( + filenames: List[Path], device: torch.device = torch.device("cpu") +) -> dict: + """Average a list of checkpoints. + The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. + + Args: + filenames: + Filenames of the checkpoints to be averaged. We assume all + checkpoints are saved by :func:`save_checkpoint`. + device: + Move checkpoints to this device before averaging. + Returns: + Return a dict (i.e., state_dict) which is the average of all + model state dicts contained in the checkpoints. + """ + n = len(filenames) + + if "model" in torch.load(filenames[0], map_location=device): + avg = torch.load(filenames[0], map_location=device)["model"] + else: + avg = torch.load(filenames[0], map_location=device) + + # Identify shared parameters. Two parameters are said to be shared + # if they have the same data_ptr + uniqued: Dict[int, str] = dict() + + for k, v in avg.items(): + v_data_ptr = v.data_ptr() + if v_data_ptr in uniqued: + continue + uniqued[v_data_ptr] = k + + uniqued_names = list(uniqued.values()) + + for i in range(1, n): + if "model" in torch.load(filenames[i], map_location=device): + state_dict = torch.load(filenames[i], map_location=device)["model"] + else: + state_dict = torch.load(filenames[i], map_location=device) + for k in uniqued_names: + avg[k] += state_dict[k] + + for k in uniqued_names: + if avg[k].is_floating_point(): + avg[k] /= n + else: + avg[k] //= n + + return avg + + +def remove_punctuation(text: str or List[str]): + """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py + + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings without any punctuation. 
+ """ + punctuation = "!,.;:?、!,。;:?《》 " + if isinstance(text, str): + text = re.sub(r"[{}]+".format(punctuation), "", text).strip() + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = re.sub(r"[{}]+".format(punctuation), "", t).strip() + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type {type(text)}") + + +def to_simple(text: str or List[str]): + """Convert traditional Chinese to simplified Chinese. + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings converted to simplified Chinese. + """ + if isinstance(text, str): + text = convert(text, "zh-cn") + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = convert(t, "zh-cn") + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type{type(text)}") + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=-1, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="beam-search", + help="""Decoding method. + Supported values are: + - beam-search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=1, + help="beam size for beam search decoding", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. + """, + ) + + parser.add_argument( + "--remove-whisper-encoder-input-length-restriction", + type=str2bool, + default=True, + help="replace whisper encoder forward method to remove input length restriction", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: "beam-search" + - value: A list of lists. Each sublist is a list of token IDs. + Args: + params: + It is returned by :func:`get_params`. + model: + The neural model. + batch: + It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. + Returns: + Return a dict, whose key may be "beam-search". 
+ """ + dtype = torch.float16 + device = torch.device("cuda") + + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device, dtype=dtype).transpose(1, 2) + if not params.remove_whisper_encoder_input_length_restriction: + T = 3000 + if feature.shape[2] < T: + feature = torch.cat( + [ + feature, + torch.zeros( + feature.shape[0], feature.shape[1], T - feature.shape[2] + ).to(device, dtype=dtype), + ], + 2, + ) + + supervisions = batch["supervisions"] + feature_len = supervisions["num_frames"] + feature_len = feature_len.to(device, dtype=dtype) + results = model.decode(feature, params.decoding_options) + hyps = [result.text for result in results] + + hyps = remove_punctuation(hyps) + hyps = to_simple(hyps) + hyps = [params.normalizer.normalize(hyp) for hyp in hyps] + print(hyps) + return {"beam-search": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + The dataloader. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "beam-search". + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + batch=batch, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + # we compute CER for aishell dataset. 
+ results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + setup_logger( + f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + ) + + options = whisper.DecodingOptions( + task="transcribe", + language="zh", + without_timestamps=True, + beam_size=params.beam_size, + ) + params.decoding_options = options + params.cleaner = BasicTextNormalizer() + params.normalizer = Normalizer() + + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + + logging.info(f"device: {device}") + + if params.remove_whisper_encoder_input_length_restriction: + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + if params.epoch > 0: + if params.avg > 1: + start = params.epoch - params.avg + assert start >= 1, start + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + # deepspeed converted checkpoint only contains model state_dict + filenames = [ + f"{params.exp_dir}/epoch-{epoch}.pt" + for epoch in range(start, params.epoch + 1) + ] + model.load_state_dict(average_checkpoints(filenames)) + else: + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + # save checkpoints + filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save(model.state_dict(), filename) + else: + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + model.load_state_dict(checkpoint, strict=True) + else: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index) + + def remove_long_utt(c: Cut): + # Keep only utterances with duration in 30 seconds + # + if c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dls = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/speechio/ASR/whisper/multi_dataset.py b/egs/speechio/ASR/whisper/multi_dataset.py new file mode 100644 index 000000000..f55d45394 --- /dev/null +++ b/egs/speechio/ASR/whisper/multi_dataset.py @@ -0,0 +1,59 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import logging +import re +from pathlib import Path +from typing import Dict, List + +import lhotse +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, fbank_dir: str, start_index: int = 0, end_index: int = 26): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - speechio_cuts_SPEECHIO_ASR_ZH00000.jsonl.gz + ... 
+ - speechio_cuts_SPEECHIO_ASR_ZH00026.jsonl.gz + """ + self.fbank_dir = Path(fbank_dir) + self.start_index = start_index + self.end_index = end_index + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + dataset_parts = [] + for i in range(self.start_index, self.end_index + 1): + idx = f"{i}".zfill(2) + dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") + + prefix = "speechio" + suffix = "jsonl.gz" + + results_dict = {} + for partition in dataset_parts: + path = f"{prefix}_cuts_{partition}.{suffix}" + + logging.info(f"Loading {path} set in lazy mode") + test_cuts = load_manifest_lazy(self.fbank_dir / path) + results_dict[partition] = test_cuts + + return results_dict diff --git a/egs/speechio/ASR/whisper/requirements.txt b/egs/speechio/ASR/whisper/requirements.txt new file mode 120000 index 000000000..744bf8bb6 --- /dev/null +++ b/egs/speechio/ASR/whisper/requirements.txt @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/requirements.txt \ No newline at end of file diff --git a/egs/speechio/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/speechio/ASR/whisper/whisper_encoder_forward_monkey_patch.py new file mode 120000 index 000000000..2a7808921 --- /dev/null +++ b/egs/speechio/ASR/whisper/whisper_encoder_forward_monkey_patch.py @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/asr_datamodule.py b/egs/speechio/ASR/zipformer/asr_datamodule.py new file mode 120000 index 000000000..bf446dabe --- /dev/null +++ b/egs/speechio/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../whisper/asr_datamodule.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/beam_search.py b/egs/speechio/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/speechio/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/ctc_decode.py b/egs/speechio/ASR/zipformer/ctc_decode.py new file mode 100644 index 000000000..f9d0db993 --- /dev/null +++ b/egs/speechio/ASR/zipformer/ctc_decode.py @@ -0,0 +1,623 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
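For reference, the MultiDataset.test_cuts() helper added above builds one CutSet per SpeechIO partition, deriving both the partition name and the manifest filename from the --start-index/--end-index range. A minimal standalone sketch of that naming scheme (not part of the patch; the function names here are illustrative only):

    from pathlib import Path

    def speechio_partitions(start_index: int = 0, end_index: int = 26):
        # Index 3 -> "SPEECHIO_ASR_ZH00003", index 26 -> "SPEECHIO_ASR_ZH00026"
        return [
            f"SPEECHIO_ASR_ZH000{str(i).zfill(2)}"
            for i in range(start_index, end_index + 1)
        ]

    def speechio_manifest_paths(fbank_dir: str, start_index: int = 0, end_index: int = 26):
        # Mirrors the path pattern used in test_cuts(): {prefix}_cuts_{partition}.{suffix}
        return {
            part: Path(fbank_dir) / f"speechio_cuts_{part}.jsonl.gz"
            for part in speechio_partitions(start_index, end_index)
        }

    # speechio_manifest_paths("data/fbank", 14, 15) yields the two manifests decoded by
    # the "--start-index 14 --end-index 15" example in whisper/decode.py above.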
+""" +Usage: + +(1) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import get_lattice, one_best_decoding +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. 
+ Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + H=H, + bpe_model=bpe_model, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = list(ref_text.replace(" ", "")) + hyp_words = list("".join(hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ("ctc-decoding",) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=True, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + + G = None + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index) + + test_sets_cuts = multi_dataset.test_cuts() + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/speechio/ASR/zipformer/decode.py b/egs/speechio/ASR/zipformer/decode.py new file mode 100644 index 000000000..ffdd7b500 --- /dev/null +++ b/egs/speechio/ASR/zipformer/decode.py @@ -0,0 +1,843 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
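The remove_short_utt() filter above drops cuts whose feature length would collapse to zero after the encoder front-end; the expression ((c.num_frames - 7) // 2 + 1) // 2 reproduces the number of frames that survive the model's input subsampling (roughly a factor of four, which is an assumption about the zipformer front-end, not stated in the patch). A standalone sketch of the check, not part of the patch:

    def frames_after_subsampling(num_frames: int) -> int:
        # Same arithmetic as remove_short_utt() in ctc_decode.py above.
        return ((num_frames - 7) // 2 + 1) // 2

    for n in (7, 9, 13, 100):
        print(n, "->", frames_after_subsampling(n))
    # 7 -> 0 (cut excluded), 9 -> 1, 13 -> 2, 100 -> 23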
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. 
Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). 
+ """, + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + 
num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list(str(text).replace(" ", "")) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + hyp_text = "".join(hyp_words) + this_batch.append((cut_id, ref_text, hyp_text)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + params.suffix += f"-blank-penalty-{params.blank_penalty}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the 
averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/speechio/ASR/zipformer/decoder.py b/egs/speechio/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/speechio/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/encoder_interface.py b/egs/speechio/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/speechio/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/joiner.py b/egs/speechio/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/speechio/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/model.py b/egs/speechio/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/speechio/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/multi_dataset.py b/egs/speechio/ASR/zipformer/multi_dataset.py new file mode 120000 index 000000000..af164667a --- /dev/null +++ b/egs/speechio/ASR/zipformer/multi_dataset.py @@ -0,0 +1 @@ +../whisper/multi_dataset.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/optim.py 
b/egs/speechio/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/speechio/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/scaling.py b/egs/speechio/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/speechio/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/scaling_converter.py b/egs/speechio/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/speechio/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/subsampling.py b/egs/speechio/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/speechio/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/train.py b/egs/speechio/ASR/zipformer/train.py new file mode 120000 index 000000000..ad7216cf7 --- /dev/null +++ b/egs/speechio/ASR/zipformer/train.py @@ -0,0 +1 @@ +../../../multi_zh-hans/ASR/zipformer/train.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/zipformer.py b/egs/speechio/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/speechio/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index 1af08fee2..ac4e92ec5 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -16,11 +16,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging from pathlib import Path import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -30,8 +38,31 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) torch.multiprocessing.set_sharing_strategy("file_system") +from icefall.utils import str2bool -def compute_fbank_wenetspeech_dev_test(): + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. 
Default: False.", + ) + return parser + + +def compute_fbank_wenetspeech_dev_test(args): in_out_dir = Path("data/fbank") # number of workers in dataloader num_workers = 42 @@ -44,7 +75,12 @@ def compute_fbank_wenetspeech_dev_test(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + if args.whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) + else: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") @@ -82,7 +118,11 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_wenetspeech_dev_test() + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + compute_fbank_wenetspeech_dev_test(args) if __name__ == "__main__": diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py index a87801462..804a302bd 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py @@ -22,15 +22,19 @@ from datetime import datetime from pathlib import Path import torch -from lhotse import ( +from lhotse import ( # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig, CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, set_audio_duration_mismatch_tolerance, set_caching_enabled, ) +from icefall.utils import get_executor, str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -87,6 +91,27 @@ def get_parser(): default=-1, help="Stop processing pieces until this number (excluded).", ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + + parser.add_argument( + "--output-dir-prefix", + type=str, + default="", + help="Prefix of the output directory.", + ) return parser @@ -96,6 +121,7 @@ def compute_fbank_wenetspeech_splits(args): num_splits = args.num_splits output_dir = f"data/fbank/{subset}_split_{num_splits}" output_dir = Path(output_dir) + output_dir = Path(args.output_dir_prefix) / output_dir assert output_dir.exists(), f"{output_dir} does not exist!" num_digits = len(str(num_splits)) @@ -110,14 +136,21 @@ def compute_fbank_wenetspeech_splits(args): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + if args.whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device=device) + ) + # extractor = KaldifeatWhisperFbank(KaldifeatWhisperFbankConfig(num_filters=args.num_mel_bins, device=device)) + else: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance set_caching_enabled(False) + # with get_executor() as ex: # Initialize the executor only once. 
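Both WenetSpeech fbank scripts above now pick their extractor from the new --whisper-fbank and --num-mel-bins options. The sketch below only illustrates that shared selection logic; it assumes nothing beyond the lhotse classes already imported in the diff, and build_extractor is an invented helper name, not part of the patch.

import torch
from lhotse import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    WhisperFbank,
    WhisperFbankConfig,
)


def build_extractor(whisper_fbank: bool, num_mel_bins: int = 80):
    # Prefer a GPU extractor when one is available, mirroring the scripts above.
    device = (
        torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    )
    if whisper_fbank:
        # Whisper-style log-mel features with the requested number of bins.
        return WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device=device)
        )
    # Default: Kaldi-compatible fbank computed with kaldifeat.
    return KaldifeatFbank(KaldifeatFbankConfig(device=device))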
for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) - logging.info(f"Processing {idx}/{num_splits}") + idx = f"{i}".zfill(num_digits) + logging.info(f"Processing {i+1}/{num_splits}") cuts_path = output_dir / f"cuts_{subset}.{idx}.jsonl.gz" if cuts_path.is_file(): @@ -143,7 +176,6 @@ def compute_fbank_wenetspeech_splits(args): storage_type=LilcomChunkyWriter, overwrite=True, ) - logging.info(f"Saving to {cuts_path}") cut_set.to_file(cuts_path) diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 543d19ce0..e3e28bd24 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -182,6 +182,43 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then fi fi +whisper_mel_bins=80 +if [ $stage -le 129 ] && [ $stop_stage -ge 129 ]; then + log "Stage 129: compute whisper fbank for dev and test sets" + python3 ./local/compute_fbank_wenetspeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true +fi +if [ $stage -le 130 ] && [ $stop_stage -ge 130 ]; then + log "Stage 130: Comute features for whisper training set" + + split_dir=data/fbank/L_split_${num_splits} + if [ ! -f $split_dir/.split_completed ]; then + lhotse split $num_splits ./data/fbank/cuts_L_raw.jsonl.gz $split_dir + touch $split_dir/.split_completed + fi + + python3 ./local/compute_fbank_wenetspeech_splits.py \ + --training-subset L \ + --num-workers 8 \ + --batch-duration 1600 \ + --start 0 \ + --num-mel-bins ${whisper_mel_bins} --whisper-fbank true \ + --num-splits $num_splits + + if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then + pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz") + lhotse combine $pieces data/fbank/cuts_L.jsonl.gz + fi +fi + +if [ $stage -le 131 ] && [ $stop_stage -ge 131 ]; then + log "Stage 131: concat feats into train set" + if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then + pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz") + lhotse combine $pieces data/fbank/cuts_L.jsonl.gz + fi +fi + + if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then log "Stage 14: Compute fbank for musan" mkdir -p data/fbank @@ -272,7 +309,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then mkdir -p $text_out_dir log "Genearating training text data" - + if [ ! -f $text_out_dir/lm_data.pt ]; then ./local/prepare_char_lm_training_data.py \ --lang-char data/lang_char \ @@ -281,14 +318,14 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then fi log "Generating DEV text data" - # prepare validation text data + # prepare validation text data if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then valid_text=${text_out_dir}/ gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \ | jq '.text' | sed 's/"//g' \ | ./local/text2token.py -t "char" > $text_out_dir/valid_text - + python3 ./local/text2segments.py \ --num-process $nj \ --input-file $text_out_dir/valid_text \ @@ -300,7 +337,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then --lm-data $text_out_dir/valid_text_words_segmentation \ --lm-archive $text_out_dir/lm_data_valid.pt - # prepare TEST text data + # prepare TEST text data if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then log "Prepare text for test set." 
for test_set in TEST_MEETING TEST_NET; do @@ -313,7 +350,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then --input-file $text_out_dir/${test_set}_text \ --output-file $text_out_dir/${test_set}_text_words_segmentation done - + cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation fi diff --git a/egs/wenetspeech/ASR/whisper/asr_datamodule.py b/egs/wenetspeech/ASR/whisper/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/wenetspeech/ASR/whisper/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/decode.py b/egs/wenetspeech/ASR/whisper/decode.py new file mode 100755 index 000000000..103f8d725 --- /dev/null +++ b/egs/wenetspeech/ASR/whisper/decode.py @@ -0,0 +1,526 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Wei Kang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +# Command for decoding using fine-tuned models: +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --beam-size 10 --max-duration 50 + +# Command for decoding using pretrained models (before fine-tuning): + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 10 --max-duration 50 + +""" + +import argparse +import logging +import re +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +import whisper +from asr_datamodule import WenetSpeechAsrDataModule +from lhotse.cut import Cut +from tn.chinese.normalizer import Normalizer +from whisper.normalizers import BasicTextNormalizer +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward +from zhconv import convert + +from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def average_checkpoints( + filenames: List[Path], device: torch.device = torch.device("cpu") +) -> dict: + """Average a list of checkpoints. + The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. + + Args: + filenames: + Filenames of the checkpoints to be averaged. We assume all + checkpoints are saved by :func:`save_checkpoint`. + device: + Move checkpoints to this device before averaging. 
+ Returns: + Return a dict (i.e., state_dict) which is the average of all + model state dicts contained in the checkpoints. + """ + n = len(filenames) + + if "model" in torch.load(filenames[0], map_location=device): + avg = torch.load(filenames[0], map_location=device)["model"] + else: + avg = torch.load(filenames[0], map_location=device) + + # Identify shared parameters. Two parameters are said to be shared + # if they have the same data_ptr + uniqued: Dict[int, str] = dict() + + for k, v in avg.items(): + v_data_ptr = v.data_ptr() + if v_data_ptr in uniqued: + continue + uniqued[v_data_ptr] = k + + uniqued_names = list(uniqued.values()) + + for i in range(1, n): + if "model" in torch.load(filenames[i], map_location=device): + state_dict = torch.load(filenames[i], map_location=device)["model"] + else: + state_dict = torch.load(filenames[i], map_location=device) + for k in uniqued_names: + avg[k] += state_dict[k] + + for k in uniqued_names: + if avg[k].is_floating_point(): + avg[k] /= n + else: + avg[k] //= n + + return avg + + +def remove_punctuation(text: str or List[str]): + """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py + + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings without any punctuation. + """ + punctuation = "!,.;:?、!,。;:?《》 " + if isinstance(text, str): + text = re.sub(r"[{}]+".format(punctuation), "", text).strip() + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = re.sub(r"[{}]+".format(punctuation), "", t).strip() + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type {type(text)}") + + +def to_simple(text: str or List[str]): + """Convert traditional Chinese to simplified Chinese. + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings converted to simplified Chinese. + """ + if isinstance(text, str): + text = convert(text, "zh-cn") + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = convert(t, "zh-cn") + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type{type(text)}") + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=-1, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="beam-search", + help="""Decoding method. + Supported values are: + - beam-search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=1, + help="beam size for beam search decoding", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. 
+ """, + ) + + parser.add_argument( + "--remove-whisper-encoder-input-length-restriction", + type=str2bool, + default=True, + help="replace whisper encoder forward method to remove input length restriction", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: "beam-search" + - value: A list of lists. Each sublist is a list of token IDs. + Args: + params: + It is returned by :func:`get_params`. + model: + The neural model. + batch: + It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. + Returns: + Return a dict, whose key may be "beam-search". + """ + dtype = torch.float16 + device = torch.device("cuda") + + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device, dtype=dtype).transpose(1, 2) + if not params.remove_whisper_encoder_input_length_restriction: + T = 3000 + if feature.shape[2] < T: + feature = torch.cat( + [ + feature, + torch.zeros( + feature.shape[0], feature.shape[1], T - feature.shape[2] + ).to(device, dtype=dtype), + ], + 2, + ) + + supervisions = batch["supervisions"] + feature_len = supervisions["num_frames"] + feature_len = feature_len.to(device, dtype=dtype) + results = model.decode(feature, params.decoding_options) + hyps = [result.text for result in results] + + hyps = remove_punctuation(hyps) + hyps = to_simple(hyps) + hyps = [params.normalizer.normalize(hyp) for hyp in hyps] + print(hyps) + return {"beam-search": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + The dataloader. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "beam-search". + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + batch=batch, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + # we compute CER for aishell dataset. + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + setup_logger( + f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + ) + + options = whisper.DecodingOptions( + task="transcribe", + language="zh", + without_timestamps=True, + beam_size=params.beam_size, + ) + params.decoding_options = options + params.cleaner = BasicTextNormalizer() + params.normalizer = Normalizer() + + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + + logging.info(f"device: {device}") + + if params.remove_whisper_encoder_input_length_restriction: + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + if params.epoch > 0: + if params.avg > 1: + start = params.epoch - params.avg + assert start >= 1, start + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + # deepspeed converted checkpoint only contains model state_dict + filenames = [ + f"{params.exp_dir}/epoch-{epoch}.pt" + for epoch in range(start, params.epoch + 1) + ] + model.load_state_dict(average_checkpoints(filenames)) + else: + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + # save checkpoints + filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save(model.state_dict(), filename) + else: + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + model.load_state_dict(checkpoint, strict=True) + else: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + wenetspeech = WenetSpeechAsrDataModule(args) + dev_cuts = wenetspeech.valid_cuts() + dev_dl = wenetspeech.valid_dataloaders(dev_cuts) + + def remove_long_utt(c: Cut): + # Keep only utterances with duration in 30 seconds + # + if c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + test_net_cuts = wenetspeech.test_net_cuts() + test_net_cuts = test_net_cuts.filter(remove_long_utt) + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) + + test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) + + # test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] + # test_dls = [dev_dl, test_net_dl, test_meeting_dl] + + test_sets = ["TEST_NET"] + test_dls = [test_net_dl] + + # test_sets = ["TEST_MEETING"] + # test_dls = [test_meeting_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/whisper/ds_config_zero1.json b/egs/wenetspeech/ASR/whisper/ds_config_zero1.json new file mode 120000 index 000000000..af7162d6c --- /dev/null +++ b/egs/wenetspeech/ASR/whisper/ds_config_zero1.json @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/ds_config_zero1.json \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/label_smoothing.py b/egs/wenetspeech/ASR/whisper/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/wenetspeech/ASR/whisper/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/optim.py b/egs/wenetspeech/ASR/whisper/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/wenetspeech/ASR/whisper/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/requirements.txt b/egs/wenetspeech/ASR/whisper/requirements.txt new file mode 120000 index 000000000..744bf8bb6 --- /dev/null +++ b/egs/wenetspeech/ASR/whisper/requirements.txt @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/requirements.txt \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py new file mode 100644 index 000000000..4b7c1ca42 --- /dev/null +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -0,0 +1,955 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
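The average_checkpoints helper added in whisper/decode.py above exists because deepspeed-converted checkpoints may contain only a bare model state_dict. Below is a self-contained toy of the same averaging arithmetic; the file names and tensors are invented for the example, and the real helper additionally de-duplicates parameters that share storage and unwraps a "model" key when present.

import tempfile
from pathlib import Path

import torch

tmp_dir = Path(tempfile.mkdtemp())
filenames = []
for epoch, value in ((9, 1.0), (10, 3.0)):
    path = tmp_dir / f"epoch-{epoch}.pt"
    # Plain state_dicts standing in for deepspeed-converted checkpoints.
    torch.save({"proj.weight": torch.full((2, 2), value)}, path)
    filenames.append(path)

avg = torch.load(filenames[0], map_location="cpu")
for path in filenames[1:]:
    state = torch.load(path, map_location="cpu")
    for k in avg:
        avg[k] += state[k]
for k in avg:
    if avg[k].is_floating_point():
        avg[k] /= len(filenames)
    else:
        avg[k] //= len(filenames)

print(avg["proj.weight"])  # every entry is (1.0 + 3.0) / 2 = 2.0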
+""" +Usage: + +#fine-tuning with deepspeed zero stage 1 +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --deepspeed \ + --deepspeed_config ./whisper/ds_config_zero1.json + +# fine-tuning with ddp +torchrun --nproc_per_node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_medium \ + --base-lr 1e-5 \ + --model-name medium +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import deepspeed +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import whisper +from asr_datamodule import WenetSpeechAsrDataModule +from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict +from label_smoothing import LabelSmoothingLoss +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.functional import pad as pad_tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import update_averaged_model +from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. 
+ """, + ) + + parser.add_argument( + "--base-lr", type=float, default=1e-5, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser = deepspeed.add_config_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - frame_shift_ms: The frame shift in milliseconds. + - allowed_excess_duration_ratio: The allowed excess duration ratio. + - best_train_loss: The best training loss so far. + - best_valid_loss: The best validation loss so far. + - best_train_epoch: The epoch where the best training loss is achieved. + - best_valid_epoch: The epoch where the best validation loss is achieved. + - batch_idx_train: The batch index of the current batch. + - log_interval: Log training stats every `log_interval` batches. + - reset_interval: Reset the stats every `reset_interval` batches. + - valid_interval: Run validation every `valid_interval` batches. + - env_info: The environment information. 
+ """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "subsampling_factor": 2, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute the loss for the given batch. + Args: + params: + It is returned by :func:`get_params`. + tokenizer: + The tokenizer used to encode the text. + model: + The model for training. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + Whether it is training. + Returns: + Return a tuple of two elements. The first element is the loss tensor. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor: + padding_size = max(tensor.shape[0] for tensor in tensors) + dims = len(tensors[0].shape) + padded_tensors = [] + for tensor in tensors: + padding = [0] * 2 * dims + padding[-1] = padding_size - tensor.shape[0] + padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) + return torch.stack([tensor for tensor in padded_tensors], dim=0) + + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + + assert feature.ndim == 3 + feature = feature.to(device) + feature = feature.transpose(1, 2) # (N, C, T) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + # remove spaces in texts + texts = [text.replace(" ", "") for text in texts] + + text_tokens_list = [ + list(tokenizer.sot_sequence_including_notimestamps) + + tokenizer.encode(text) + + [tokenizer.eot] + for text in texts + ] + # convert it to torch tensor + text_tokens_list = [ + torch.LongTensor(text_tokens) for text_tokens in text_tokens_list + ] + + # 50256 is the index of for all whisper models + prev_outputs_tokens = _batch_tensors( + [tokens[:-1] for tokens in text_tokens_list], pad_value=50256 + ) + target_tokens = _batch_tensors( + [tokens[1:] for tokens in text_tokens_list], pad_value=50256 + ) + target_lengths = torch.LongTensor( + [tokens.shape[0] - 1 for 
tokens in text_tokens_list] + ) + + decoder_criterion = LabelSmoothingLoss( + ignore_index=50256, label_smoothing=0.1, reduction="sum" + ) + + # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|> + ignore_prefix_size = 3 + with torch.set_grad_enabled(is_training): + encoder_out = model.encoder(feature) + text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) + text_logits = text_logits[:, ignore_prefix_size:, :] + target_tokens = target_tokens[:, ignore_prefix_size:] + loss = decoder_criterion(text_logits, target_tokens.to(device)) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + tokenizer=tokenizer, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + if params.deepspeed: + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", + client_state={}, + ) + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", + tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", + ) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + if params.deepspeed: + # deepspeed's backward() is different from torch's backward() + # in that it does not accept a loss tensor as input. + # It computes the loss internally. + model.backward(loss) + model.step() + else: + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + and not params.deepspeed + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + if batch_idx % params.log_interval == 0: + try: + cur_lr = scheduler.get_last_lr()[0] + except: # noqa + cur_lr = 0.0 + cur_grad_scale = ( + scaler._scale.item() + if (params.use_fp16 and not params.deepspeed) + else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if (params.use_fp16 and not params.deepspeed) + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info(params) + + logging.info("About to create model") + + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + del model.alignment_heads + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + tokenizer = whisper.tokenizer.get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language="zh", + task="transcribe", + ) + + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + else: + device = torch.device("cpu") + logging.info(f"Device: {device}") + model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if world_size > 1: + if params.deepspeed: + logging.info("Using DeepSpeed") + model, optimizer, _, scheduler = deepspeed.initialize( + args=params, model=model, 
model_parameters=model.parameters() + ) + else: + logging.info("Using DDP") + setup_dist(use_ddp_launch=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + wenetspeech = WenetSpeechAsrDataModule(args) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 15.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = wenetspeech.train_cuts() + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = wenetspeech.train_dataloaders(train_cuts) + valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts()) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + logging.info(f"start training from epoch {params.start_epoch}") + for epoch in range(params.start_epoch, params.num_epochs + 1): + if not params.deepspeed: + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + tokenizer=tokenizer, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.deepspeed: + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}", + client_state={}, + ) + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", + tag=f"epoch-{params.cur_epoch}", + ) + else: + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1 and not params.deepspeed: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = get_world_size() + rank = get_rank() + + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + run(rank=rank, world_size=world_size, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py new file mode 120000 index 000000000..2a7808921 --- /dev/null +++ b/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py \ No newline at end of file From ae61bd4090ad3e8c981aa77cc4fc417d095962c4 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 8 Mar 2024 11:01:11 +0800 Subject: [PATCH 119/216] Minor fixes for the `commonvoice` recipe (#1534) * init commit * fix for issue https://github.com/k2-fsa/icefall/issues/1531 * minor fixes --- egs/commonvoice/ASR/local/compile_hlg.py | 169 +++++++++++++++++- egs/commonvoice/ASR/local/compile_lg.py | 150 +++++++++++++++- .../ASR/local/preprocess_commonvoice.py | 9 + .../asr_datamodule.py | 8 +- .../commonvoice_fr.py | 14 +- .../zipformer_prompt_asr/asr_datamodule.py | 8 +- 6 files changed, 344 insertions(+), 14 deletions(-) mode change 120000 => 100755 egs/commonvoice/ASR/local/compile_hlg.py mode change 120000 => 100755 egs/commonvoice/ASR/local/compile_lg.py diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/commonvoice/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py new file mode 100755 index 000000000..6512aa68b --- /dev/null +++ b/egs/commonvoice/ASR/local/compile_hlg.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +This script takes as input lang_dir and generates HLG from + + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_n_gram.fst.txt + +The generated HLG is saved in $lang_dir/HLG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. + + Return: + An FSA representing HLG. + """ + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") + H = k2.ctc_topo(max_token_id) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path(f"{lang_dir}/lm/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"{lang_dir}/lm/{lm}.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info(f"Loading {lm}.fst.txt") + with open(f"{lang_dir}/lm/{lm}.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + # CAUTION: The name of the inner_labels is fixed + # to `tokens`. If you want to change it, please + # also change other places in icefall that are using + # it. 
+ HLG = k2.compose(H, LG, inner_labels="tokens") + + logging.info("Connecting LG") + HLG = k2.connect(HLG) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + logging.info(f"HLG.shape: {HLG.shape}") + + return HLG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + HLG = compile_HLG(lang_dir, args.lm) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/commonvoice/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py new file mode 100755 index 000000000..76dacb5b2 --- /dev/null +++ b/egs/commonvoice/ASR/local/compile_lg.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Kang Wei, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates LG from + + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from lang_dir/lm/G_3_gram.fst.txt + +The generated LG is saved in $lang_dir/LG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) + + return parser.parse_args() + + +def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + + Return: + An FSA representing LG. 
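+
+    A typical invocation of this script, assuming a lang dir laid out as in
+    the example above, is e.g.:
+
+        ./local/compile_lg.py --lang-dir data/lang_bpe_5000 --lm G_3_gram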
+ """ + lexicon = Lexicon(lang_dir) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path(f"{lang_dir}/lm/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"{lang_dir}/lm/{lm}.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info(f"Loading {lm}.fst.txt") + with open(f"{lang_dir}/lm/{lm}.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + return LG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "LG.pt").is_file(): + logging.info(f"{lang_dir}/LG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + LG = compile_LG(lang_dir, args.lm) + logging.info(f"Saving LG.pt to {lang_dir}") + torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index c0f4ca427..dbacdd821 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -52,6 +52,15 @@ def normalize_text(utt: str, language: str) -> str: return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper() elif language == "pl": return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper() + elif language == "yue": + return ( + utt.replace(" ", "") + .replace(",", "") + .replace("。", " ") + .replace("?", "") + .replace("!", "") + .replace("?", "") + ) else: raise NotImplementedError( f""" diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py index c40d9419b..41009831c 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if 
self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)() + ), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py index 79cf86b84..da8e62034 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)() + ), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py index 1a4c9a532..552f63905 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py @@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures() + ), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( From 60986c3ac159a1ffc75f6cca4392518d5f927f1f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Mar 2024 20:47:13 +0800 Subject: [PATCH 120/216] Fix default value for --context-size in icefall. (#1538) --- .../ASR/pruned_transducer_stateless7_streaming/decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py index 99110d6b6..0e783e92b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -250,7 +250,7 @@ def get_parser(): parser.add_argument( "--context-size", type=int, - default=1, + default=2, help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( From e472fa68403ebe56112cfb581adbd88b8f1f1f3e Mon Sep 17 00:00:00 2001 From: jimmy1984xu Date: Mon, 11 Mar 2024 18:37:26 +0800 Subject: [PATCH 121/216] fix CutMix init parameter (#1543) Co-authored-by: jimmyxu --- .../pruned_transducer_stateless7_streaming/commonvoice_fr.py | 2 +- egs/libriheavy/ASR/zipformer/asr_datamodule.py | 2 +- egs/multi_zh_en/ASR/zipformer/asr_datamodule.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py index da8e62034..91220bd11 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py @@ -232,7 +232,7 @@ class CommonVoiceAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py index e23c9b1b7..4985f3f4c 100644 --- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -232,7 +232,7 @@ class LibriHeavyAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py index 662ae01c5..489b38e65 100644 --- a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py @@ -216,7 +216,7 @@ class AsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") From 959906e9dcded13ef0618c6ea815a18ed3f19324 Mon Sep 17 00:00:00 2001 From: BannerWang <45093113+Banner-Wang@users.noreply.github.com> Date: Tue, 12 Mar 2024 12:44:09 +0800 Subject: [PATCH 122/216] Correct alimeeting download link (#1544) Co-authored-by: BannerWang --- egs/alimeeting/ASR/prepare.sh | 2 +- egs/alimeeting/ASR_v2/prepare.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 301ab0111..996a1da2d 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -15,7 +15,7 @@ perturb_speed=true # # - $dl_dir/alimeeting # This directory contains the following files downloaded from -# https://openslr.org/62/ +# https://openslr.org/119/ # # - Train_Ali_far.tar.gz # - Train_Ali_near.tar.gz diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh index 1098840f8..15c20692d 100755 --- a/egs/alimeeting/ASR_v2/prepare.sh +++ b/egs/alimeeting/ASR_v2/prepare.sh @@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting # # - $dl_dir/alimeeting # This directory 
contains the following files downloaded from -# https://openslr.org/62/ +# https://openslr.org/119/ # # - Train_Ali_far.tar.gz # - Train_Ali_near.tar.gz From 81f518ea7c4dc1e709bb10f21aac55dd33712649 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Mar 2024 22:29:21 +0800 Subject: [PATCH 123/216] Support different tts model types. (#1541) --- docs/source/recipes/TTS/ljspeech/vits.rst | 17 ++++-- egs/ljspeech/TTS/README.md | 73 +++++++++++++++++++++-- egs/ljspeech/TTS/prepare.sh | 2 +- egs/ljspeech/TTS/vits/export-onnx.py | 40 ++++++------- egs/ljspeech/TTS/vits/generator.py | 2 +- egs/ljspeech/TTS/vits/infer.py | 11 ++++ egs/ljspeech/TTS/vits/test_model.py | 50 ++++++++++++++++ egs/ljspeech/TTS/vits/test_onnx.py | 27 +++++++-- egs/ljspeech/TTS/vits/text_encoder.py | 7 ++- egs/ljspeech/TTS/vits/tokenizer.py | 10 +++- egs/ljspeech/TTS/vits/train.py | 22 +++---- egs/ljspeech/TTS/vits/vits.py | 53 +++++++++++++++- 12 files changed, 265 insertions(+), 49 deletions(-) create mode 100755 egs/ljspeech/TTS/vits/test_model.py diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 323d0adfc..d31bf6302 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -56,7 +56,8 @@ Training --start-epoch 1 \ --use-fp16 1 \ --exp-dir vits/exp \ - --tokens data/tokens.txt + --tokens data/tokens.txt \ + --model-type high \ --max-duration 500 .. note:: @@ -64,6 +65,11 @@ Training You can adjust the hyper-parameters to control the size of the VITS model and the training configurations. For more details, please run ``./vits/train.py --help``. +.. warning:: + + If you want a model that runs faster on CPU, please use ``--model-type low`` + or ``--model-type medium``. + .. note:: The training can take a long time (usually a couple of days). @@ -95,8 +101,8 @@ training part first. It will save the ground-truth and generated wavs to the dir Export models ------------- -Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: -``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. +Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``: +``vits-epoch-*.onnx``. .. code-block:: bash @@ -120,4 +126,7 @@ Download pretrained models If you don't want to train from scratch, you can download the pretrained models by visiting the following link: - - ``_ + - ``--model-type=high``: ``_ + - ``--model-type=medium``: ``_ + - ``--model-type=low``: ``_ + diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 80be5a315..7b112c12c 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -1,10 +1,10 @@ # Introduction -This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. -A transcription is provided for each clip. +This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. +A transcription is provided for each clip. Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours. -The texts were published between 1884 and 1964, and are in the public domain. +The texts were published between 1884 and 1964, and are in the public domain. The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain. 
The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/). @@ -35,4 +35,69 @@ To inference, use: --exp-dir vits/exp \ --epoch 1000 \ --tokens data/tokens.txt -``` \ No newline at end of file +``` + +## Quality vs speed + +If you feel that the trained model is slow at runtime, you can specify the +argument `--model-type` during training. Possible values are: + + - `low`, means **low** quality. The resulting model is very small in file size + and runs very fast. The following is a wave file generatd by a `low` quality model + + https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633 + + The text is `Ask not what your country can do for you; ask what you can do for your country.` + + The exported onnx model has a file size of ``26.8 MB`` (float32). + + - `medium`, means **medium** quality. + The following is a wave file generatd by a `medium` quality model + + https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac + + The text is `Ask not what your country can do for you; ask what you can do for your country.` + + The exported onnx model has a file size of ``70.9 MB`` (float32). + + - `high`, means **high** quality. This is the default value. + + The following is a wave file generatd by a `high` quality model + + https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc + + The text is `Ask not what your country can do for you; ask what you can do for your country.` + + The exported onnx model has a file size of ``113 MB`` (float32). + + +A pre-trained `low` model trained using 4xV100 32GB GPU with the following command can be found at + + +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3 +./vits/train.py \ + --world-size 4 \ + --num-epochs 1601 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --model-type low \ + --max-duration 800 +``` + +A pre-trained `medium` model trained using 4xV100 32GB GPU with the following command can be found at + +```bash +export CUDA_VISIBLE_DEVICES=4,5,6,7 +./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp-medium \ + --model-type medium \ + --max-duration 500 + +# (Note it is killed after `epoch-820.pt`) +``` diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index cbf27bd42..bded423ac 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -54,7 +54,7 @@ fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare LJSpeech manifest" # We assume that you have downloaded the LJSpeech corpus - # to $dl_dir/LJSpeech + # to $dl_dir/LJSpeech-1.1 mkdir -p data/manifests if [ ! -e data/manifests/.ljspeech.done ]; then lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 58b166368..0740757c0 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -25,9 +25,8 @@ Export the model to ONNX: --exp-dir vits/exp \ --tokens data/tokens.txt -It will generate two files inside vits/exp: +It will generate one file inside vits/exp: - vits-epoch-1000.onnx - - vits-epoch-1000.int8.onnx (quantizated model) See ./test_onnx.py for how to use the exported ONNX models. 
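+
+A rough sketch of loading the exported file with onnxruntime (assuming
+onnxruntime is installed; see ./test_onnx.py for the complete version):
+
+    import onnxruntime as ort
+
+    sess = ort.InferenceSession(
+        "vits/exp/vits-epoch-1000.onnx",
+        providers=["CPUExecutionProvider"],
+    )
+    # metadata written by this script, e.g. sample_rate and n_speakers
+    print(sess.get_modelmeta().custom_metadata_map)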
""" @@ -40,7 +39,6 @@ from typing import Dict, Tuple import onnx import torch import torch.nn as nn -from onnxruntime.quantization import QuantType, quantize_dynamic from tokenizer import Tokenizer from train import get_model, get_params @@ -75,6 +73,16 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -136,7 +144,7 @@ class OnnxModel(nn.Module): Return a tuple containing: - audio, generated wavform tensor, (B, T_wav) """ - audio, _, _ = self.model.inference( + audio, _, _ = self.model.generator.inference( text=tokens, text_lengths=tokens_lens, noise_scale=noise_scale, @@ -198,6 +206,11 @@ def export_model_onnx( }, ) + if model.model.spks is None: + num_speakers = 1 + else: + num_speakers = model.model.spks + meta_data = { "model_type": "vits", "version": "1", @@ -206,8 +219,8 @@ def export_model_onnx( "language": "English", "voice": "en-us", # Choose your language appropriately "has_espeak": 1, - "n_speakers": 1, - "sample_rate": 22050, # Must match the real sample rate + "n_speakers": num_speakers, + "sample_rate": model.model.sampling_rate, # Must match the real sample rate } logging.info(f"meta_data: {meta_data}") @@ -233,14 +246,13 @@ def main(): load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model = model.generator model.to("cpu") model.eval() model = OnnxModel(model=model) num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"generator parameters: {num_param}") + logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") suffix = f"epoch-{params.epoch}" @@ -256,18 +268,6 @@ def main(): ) logging.info(f"Exported generator to {model_filename}") - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" - quantize_dynamic( - model_input=model_filename, - model_output=model_filename_int8, - weight_type=QuantType.QUInt8, - ) - if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index 66c8cedb1..b9add9e82 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module): self.upsample_factor = int(np.prod(decoder_upsample_scales)) self.spks = None if spks is not None and spks > 1: - assert global_channels > 0 + assert global_channels > 0, global_channels self.spks = spks self.global_emb = torch.nn.Embedding(spks, global_channels) self.spk_embed_dim = None diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 9e7c71c6d..7be76e315 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -72,6 +72,16 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -94,6 +104,7 @@ def infer_dataset( tokenizer: Used to convert text to phonemes. 
""" + # Background worker save audios to disk. def _save_worker( batch_size: int, diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py new file mode 100755 index 000000000..1de10f012 --- /dev/null +++ b/egs/ljspeech/TTS/vits/test_model.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tokenizer import Tokenizer +from train import get_model, get_params +from vits import VITS + + +def test_model_type(model_type): + tokens = "./data/tokens.txt" + + params = get_params() + + tokenizer = Tokenizer(tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_type = model_type + + model = get_model(params) + generator = model.generator + + num_param = sum([p.numel() for p in generator.parameters()]) + print( + f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M" + ) + + +def main(): + test_model_type("high") # 35.63 M + test_model_type("low") # 7.55 M + test_model_type("medium") # 23.61 M + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 4f46e8e6c..b3805fadb 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -54,6 +54,20 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--text", + type=str, + default="Ask not what your country can do for you; ask what you can do for your country.", + help="Text to generate speech for", + ) + + parser.add_argument( + "--output-filename", + type=str, + default="test_onnx.wav", + help="Filename to save the generated wave file.", + ) + return parser @@ -61,7 +75,7 @@ class OnnxModel: def __init__(self, model_filename: str): session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 + session_opts.intra_op_num_threads = 1 self.session_opts = session_opts @@ -72,6 +86,9 @@ class OnnxModel: ) logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + metadata = self.model.get_modelmeta().custom_metadata_map + self.sample_rate = int(metadata["sample_rate"]) + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: """ Args: @@ -101,13 +118,14 @@ class OnnxModel: def main(): args = get_parser().parse_args() + logging.info(vars(args)) tokenizer = Tokenizer(args.tokens) logging.info("About to create onnx model") model = OnnxModel(args.model_filename) - text = "I went there to see the land, the people and how their system works, end quote." 
+ text = args.text tokens = tokenizer.texts_to_token_ids( [text], intersperse_blank=True, add_sos=True, add_eos=True ) @@ -115,8 +133,9 @@ def main(): tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) audio = model(tokens, tokens_lens) # (1, T') - torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) - logging.info("Saved to test_onnx.wav") + output_filename = args.output_filename + torchaudio.save(output_filename, audio, sample_rate=model.sample_rate) + logging.info(f"Saved to {output_filename}") if __name__ == "__main__": diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index fcbae7103..9b21ed9cb 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module): x_lengths (Tensor): Length tensor (B,). Returns: - Tensor: Encoded hidden representation (B, attention_dim, T_text). - Tensor: Projected mean tensor (B, attention_dim, T_text). - Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Encoded hidden representation (B, embed_dim, T_text). + Tensor: Projected mean tensor (B, embed_dim, T_text). + Tensor: Projected scale tensor (B, embed_dim, T_text). Tensor: Mask tensor for input tensor (B, 1, T_text). """ @@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module): # encoder assume the channel last (B, T_text, embed_dim) x = self.encoder(x, key_padding_mask=pad_mask) + # Note: attention_dim == embed_dim # convert the channel first (B, embed_dim, T_text) x = x.transpose(1, 2) diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 9a5a9090e..8144ffe1e 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -18,7 +18,15 @@ import logging from typing import Dict, List import tacotron_cleaner.cleaners -from piper_phonemize import phonemize_espeak + +try: + from piper_phonemize import phonemize_espeak +except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease follow instructions in " + "../prepare.sh to install piper-phonemize" + ) + from utils import intersperse diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 6589b75ff..34b943765 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -153,6 +153,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -189,15 +199,6 @@ def get_params() -> AttributeDict: - feature_dim: The model input dim. It has to match the one used in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. 
""" params = AttributeDict( { @@ -278,6 +279,7 @@ def get_model(params: AttributeDict) -> nn.Module: vocab_size=params.vocab_size, feature_dim=params.feature_dim, sampling_rate=params.sampling_rate, + model_type=params.model_type, mel_loss_params=mel_loss_params, lambda_adv=params.lambda_adv, lambda_mel=params.lambda_mel, @@ -363,7 +365,7 @@ def train_one_epoch( model.train() device = model.device if isinstance(model, DDP) else next(model.parameters()).device - # used to summary the stats over iterations in one epoch + # used to track the stats over iterations in one epoch tot_loss = MetricsTracker() saved_bad_model = False diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index b4f0c21e6..0b9575cbd 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -5,6 +5,7 @@ """VITS module for GAN-TTS task.""" +import copy from typing import Any, Dict, Optional, Tuple import torch @@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = { "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA } +LOW_CONFIG = { + "hidden_channels": 96, + "decoder_upsample_scales": (8, 8, 4), + "decoder_channels": 256, + "decoder_upsample_kernel_sizes": (16, 16, 8), + "decoder_resblock_kernel_sizes": (3, 5, 7), + "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), + "text_encoder_cnn_module_kernel": 3, +} + +MEDIUM_CONFIG = { + "hidden_channels": 192, + "decoder_upsample_scales": (8, 8, 4), + "decoder_channels": 256, + "decoder_upsample_kernel_sizes": (16, 16, 8), + "decoder_resblock_kernel_sizes": (3, 5, 7), + "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), + "text_encoder_cnn_module_kernel": 3, +} + +HIGH_CONFIG = { + "hidden_channels": 192, + "decoder_upsample_scales": (8, 8, 2, 2), + "decoder_channels": 512, + "decoder_upsample_kernel_sizes": (16, 16, 4, 4), + "decoder_resblock_kernel_sizes": (3, 7, 11), + "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + "text_encoder_cnn_module_kernel": 5, +} + class VITS(nn.Module): """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" @@ -49,6 +80,7 @@ class VITS(nn.Module): feature_dim: int = 513, sampling_rate: int = 22050, generator_type: str = "vits_generator", + model_type: str = "", generator_params: Dict[str, Any] = { "hidden_channels": 192, "spks": None, @@ -155,12 +187,13 @@ class VITS(nn.Module): """Initialize VITS module. Args: - idim (int): Input vocabrary size. + idim (int): Input vocabulary size. odim (int): Acoustic feature dimension. The actual output channels will be 1 since VITS is the end-to-end text-to-wave model but for the compatibility odim is used to indicate the acoustic feature dimension. sampling_rate (int): Sampling rate, not used for the training but it will be referred in saving waveform during the inference. + model_type (str): If not empty, must be one of: low, medium, high generator_type (str): Generator type. generator_params (Dict[str, Any]): Parameter dict for generator. discriminator_type (str): Discriminator type. 
@@ -181,6 +214,24 @@ class VITS(nn.Module): """ super().__init__() + generator_params = copy.deepcopy(generator_params) + discriminator_params = copy.deepcopy(discriminator_params) + generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params) + discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params) + feat_match_loss_params = copy.deepcopy(feat_match_loss_params) + mel_loss_params = copy.deepcopy(mel_loss_params) + + if model_type != "": + assert model_type in ("low", "medium", "high"), model_type + if model_type == "low": + generator_params.update(LOW_CONFIG) + elif model_type == "medium": + generator_params.update(MEDIUM_CONFIG) + elif model_type == "high": + generator_params.update(HIGH_CONFIG) + else: + raise ValueError(f"Unknown model_type: ${model_type}") + # define modules generator_class = AVAILABLE_GENERATERS[generator_type] if generator_type == "vits_generator": From c3f6f28116b90f6f3ca5309f0ce06c318fdfc009 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 13 Mar 2024 10:01:28 +0800 Subject: [PATCH 124/216] Zipformer recipe for Cantonese dataset MDCC (#1537) * init commit * Create README.md * handle code switching cases * misc. fixes * added manifest statistics * init commit for the zipformer recipe * added scripts for exporting model * added RESULTS.md * added scripts for streaming related stuff * doc str fixed --- egs/aishell/ASR/README.md | 4 +- egs/aishell/ASR/prepare.sh | 2 +- egs/mdcc/ASR/README.md | 19 + egs/mdcc/ASR/RESULTS.md | 41 + egs/mdcc/ASR/local/compile_hlg.py | 1 + .../ASR/local/compile_hlg_using_openfst.py | 1 + egs/mdcc/ASR/local/compile_lg.py | 1 + egs/mdcc/ASR/local/compute_fbank_mdcc.py | 157 ++ .../ASR/local/display_manifest_statistics.py | 144 ++ egs/mdcc/ASR/local/prepare_char.py | 1 + .../local/prepare_char_lm_training_data.py | 1 + egs/mdcc/ASR/local/prepare_lang.py | 1 + egs/mdcc/ASR/local/prepare_lang_fst.py | 1 + egs/mdcc/ASR/local/preprocess_mdcc.py | 157 ++ egs/mdcc/ASR/local/text2segments.py | 86 ++ egs/mdcc/ASR/local/text2token.py | 1 + egs/mdcc/ASR/prepare.sh | 308 ++++ egs/mdcc/ASR/shared | 1 + egs/mdcc/ASR/zipformer/__init__.py | 0 egs/mdcc/ASR/zipformer/asr_datamodule.py | 382 +++++ egs/mdcc/ASR/zipformer/beam_search.py | 1 + egs/mdcc/ASR/zipformer/decode.py | 813 ++++++++++ egs/mdcc/ASR/zipformer/decode_stream.py | 1 + egs/mdcc/ASR/zipformer/decoder.py | 1 + egs/mdcc/ASR/zipformer/encoder_interface.py | 1 + egs/mdcc/ASR/zipformer/export-onnx-ctc.py | 1 + .../zipformer/export-onnx-streaming-ctc.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/mdcc/ASR/zipformer/export-onnx.py | 1 + egs/mdcc/ASR/zipformer/export.py | 1 + egs/mdcc/ASR/zipformer/joiner.py | 1 + egs/mdcc/ASR/zipformer/model.py | 1 + egs/mdcc/ASR/zipformer/onnx_check.py | 1 + egs/mdcc/ASR/zipformer/onnx_decode.py | 286 ++++ egs/mdcc/ASR/zipformer/optim.py | 1 + egs/mdcc/ASR/zipformer/scaling.py | 1 + egs/mdcc/ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + egs/mdcc/ASR/zipformer/streaming_decode.py | 881 +++++++++++ egs/mdcc/ASR/zipformer/subsampling.py | 1 + egs/mdcc/ASR/zipformer/train.py | 1345 +++++++++++++++++ egs/mdcc/ASR/zipformer/zipformer.py | 1 + requirements.txt | 7 +- 43 files changed, 4655 insertions(+), 4 deletions(-) create mode 100644 egs/mdcc/ASR/README.md create mode 100644 egs/mdcc/ASR/RESULTS.md create mode 120000 egs/mdcc/ASR/local/compile_hlg.py create mode 120000 egs/mdcc/ASR/local/compile_hlg_using_openfst.py create mode 120000 egs/mdcc/ASR/local/compile_lg.py create mode 
100755 egs/mdcc/ASR/local/compute_fbank_mdcc.py create mode 100755 egs/mdcc/ASR/local/display_manifest_statistics.py create mode 120000 egs/mdcc/ASR/local/prepare_char.py create mode 120000 egs/mdcc/ASR/local/prepare_char_lm_training_data.py create mode 120000 egs/mdcc/ASR/local/prepare_lang.py create mode 120000 egs/mdcc/ASR/local/prepare_lang_fst.py create mode 100755 egs/mdcc/ASR/local/preprocess_mdcc.py create mode 100755 egs/mdcc/ASR/local/text2segments.py create mode 120000 egs/mdcc/ASR/local/text2token.py create mode 100755 egs/mdcc/ASR/prepare.sh create mode 120000 egs/mdcc/ASR/shared create mode 100644 egs/mdcc/ASR/zipformer/__init__.py create mode 100644 egs/mdcc/ASR/zipformer/asr_datamodule.py create mode 120000 egs/mdcc/ASR/zipformer/beam_search.py create mode 100755 egs/mdcc/ASR/zipformer/decode.py create mode 120000 egs/mdcc/ASR/zipformer/decode_stream.py create mode 120000 egs/mdcc/ASR/zipformer/decoder.py create mode 120000 egs/mdcc/ASR/zipformer/encoder_interface.py create mode 120000 egs/mdcc/ASR/zipformer/export-onnx-ctc.py create mode 120000 egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py create mode 120000 egs/mdcc/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/mdcc/ASR/zipformer/export-onnx.py create mode 120000 egs/mdcc/ASR/zipformer/export.py create mode 120000 egs/mdcc/ASR/zipformer/joiner.py create mode 120000 egs/mdcc/ASR/zipformer/model.py create mode 120000 egs/mdcc/ASR/zipformer/onnx_check.py create mode 100755 egs/mdcc/ASR/zipformer/onnx_decode.py create mode 120000 egs/mdcc/ASR/zipformer/optim.py create mode 120000 egs/mdcc/ASR/zipformer/scaling.py create mode 120000 egs/mdcc/ASR/zipformer/scaling_converter.py create mode 120000 egs/mdcc/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/mdcc/ASR/zipformer/streaming_decode.py create mode 120000 egs/mdcc/ASR/zipformer/subsampling.py create mode 100755 egs/mdcc/ASR/zipformer/train.py create mode 120000 egs/mdcc/ASR/zipformer/zipformer.py diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index b54719162..d088072a7 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -19,7 +19,9 @@ The following table lists the differences among them. | `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` | | `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data | | `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data| -| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 | +| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 | +| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 | + The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). 
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index b7be89bc8..13be69534 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then fi if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 11: Train RNN LM model" + log "Stage 12: Train RNN LM model" python ../../../icefall/rnn_lm/train.py \ --start-epoch 0 \ --world-size 1 \ diff --git a/egs/mdcc/ASR/README.md b/egs/mdcc/ASR/README.md new file mode 100644 index 000000000..112845b73 --- /dev/null +++ b/egs/mdcc/ASR/README.md @@ -0,0 +1,19 @@ +# Introduction + +Multi-Domain Cantonese Corpus (MDCC), consists of 73.6 hours of clean read speech paired with +transcripts, collected from Cantonese audiobooks from Hong Kong. It comprises philosophy, +politics, education, culture, lifestyle and family domains, covering a wide range of topics. + +Manuscript can be found at: https://arxiv.org/abs/2201.02419 + +# Transducers + + + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|-----------------------------| +| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 | + +The decoder is modified from the paper +[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/mdcc/ASR/RESULTS.md b/egs/mdcc/ASR/RESULTS.md new file mode 100644 index 000000000..ff7ddc957 --- /dev/null +++ b/egs/mdcc/ASR/RESULTS.md @@ -0,0 +1,41 @@ +## Results + +#### Zipformer + +See + +[./zipformer](./zipformer) + +##### normal-scaled model, number of model parameters: 74470867, i.e., 74.47 M + +| | test | valid | comment | +|------------------------|------|-------|-----------------------------------------| +| greedy search | 7.45 | 7.51 | --epoch 45 --avg 35 | +| modified beam search | 6.68 | 6.73 | --epoch 45 --avg 35 | +| fast beam search | 7.22 | 7.28 | --epoch 45 --avg 35 | + +The training command: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer/train.py \ + --world-size 4 \ + --start-epoch 1 \ + --num-epochs 50 \ + --use-fp16 1 \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 +``` + +The decoding command: + +``` + ./zipformer/decode.py \ + --epoch 45 \ + --avg 35 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search # modified_beam_search +``` + +The pretrained model is available at: https://huggingface.co/zrjin/icefall-asr-mdcc-zipformer-2024-03-11/ \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_hlg.py b/egs/mdcc/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/mdcc/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_hlg_using_openfst.py b/egs/mdcc/ASR/local/compile_hlg_using_openfst.py new file mode 120000 index 000000000..d34edd7f3 --- /dev/null +++ b/egs/mdcc/ASR/local/compile_hlg_using_openfst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg_using_openfst.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_lg.py b/egs/mdcc/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/mdcc/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git 
a/egs/mdcc/ASR/local/compute_fbank_mdcc.py b/egs/mdcc/ASR/local/compute_fbank_mdcc.py new file mode 100755 index 000000000..647b21127 --- /dev/null +++ b/egs/mdcc/ASR/local/compute_fbank_mdcc.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the aishell dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_mdcc( + num_mel_bins: int = 80, + perturb_speed: bool = False, + whisper_fbank: bool = False, + output_dir: str = "data/fbank", +): + src_dir = Path("data/manifests") + output_dir = Path(output_dir) + num_jobs = min(15, os.cpu_count()) + + dataset_parts = ( + "train", + "valid", + "test", + ) + prefix = "mdcc" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. 
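+        # For each of the train/valid/test partitions: skip it if its cuts
+        # file already exists; otherwise build a CutSet from the recordings
+        # and supervisions, optionally add 0.9x/1.1x speed perturbation to
+        # the training set, compute fbank (or Whisper fbank) features stored
+        # via LilcomChunkyWriter, and write the cuts to
+        # {output_dir}/mdcc_cuts_{partition}.jsonl.gz.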
+ for partition, m in manifests.items(): + if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition and perturb_speed: + logging.info("Doing speed perturb") + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. Default: data/fbank.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + compute_fbank_mdcc( + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, + ) diff --git a/egs/mdcc/ASR/local/display_manifest_statistics.py b/egs/mdcc/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..27cf8c943 --- /dev/null +++ b/egs/mdcc/ASR/local/display_manifest_statistics.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. 
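+
+For example, the statistics reproduced at the bottom of this file show that
+99.9% of the (speed-perturbed) MDCC training cuts are at most about 14.8
+seconds long, which suggests a reasonable upper bound for such a filter.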
+""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/fbank/mdcc_cuts_train.jsonl.gz" + path = "./data/fbank/mdcc_cuts_valid.jsonl.gz" + path = "./data/fbank/mdcc_cuts_test.jsonl.gz" + + cuts = load_manifest_lazy(path) + cuts.describe(full=True) + + +if __name__ == "__main__": + main() + +""" +data/fbank/mdcc_cuts_train.jsonl.gz (with speed perturbation) +_________________________________________ +_ Cuts count: _ 195360 +_________________________________________ +_ Total duration (hh:mm:ss) _ 173:44:59 +_________________________________________ +_ mean _ 3.2 +_________________________________________ +_ std _ 2.1 +_________________________________________ +_ min _ 0.2 +_________________________________________ +_ 25% _ 1.8 +_________________________________________ +_ 50% _ 2.7 +_________________________________________ +_ 75% _ 4.0 +_________________________________________ +_ 99% _ 11.0 _ +_________________________________________ +_ 99.5% _ 12.4 _ +_________________________________________ +_ 99.9% _ 14.8 _ +_________________________________________ +_ max _ 16.7 _ +_________________________________________ +_ Recordings available: _ 195360 _ +_________________________________________ +_ Features available: _ 195360 _ +_________________________________________ +_ Supervisions available: _ 195360 _ +_________________________________________ + +data/fbank/mdcc_cuts_valid.jsonl.gz +________________________________________ +_ Cuts count: _ 5663 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 05:03:12 _ +________________________________________ +_ mean _ 3.2 _ +________________________________________ +_ std _ 2.0 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 1.8 _ +________________________________________ +_ 50% _ 2.7 _ +________________________________________ +_ 75% _ 4.0 _ +________________________________________ +_ 99% _ 10.9 _ +________________________________________ +_ 99.5% _ 12.3 _ +________________________________________ +_ 99.9% _ 14.4 _ +________________________________________ +_ max _ 14.8 _ +________________________________________ +_ Recordings available: _ 5663 _ +________________________________________ +_ Features available: _ 5663 _ +________________________________________ +_ Supervisions available: _ 5663 _ +________________________________________ + +data/fbank/mdcc_cuts_test.jsonl.gz +________________________________________ +_ Cuts count: _ 12492 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 11:00:31 _ +________________________________________ +_ mean _ 3.2 _ +________________________________________ +_ std _ 2.0 _ +________________________________________ +_ min _ 0.2 _ +________________________________________ +_ 25% _ 1.8 _ +________________________________________ +_ 50% _ 2.7 _ +________________________________________ +_ 75% _ 4.0 _ +________________________________________ +_ 99% _ 10.5 _ +________________________________________ +_ 99.5% _ 12.1 _ +________________________________________ +_ 99.9% _ 14.0 _ +________________________________________ +_ max _ 14.8 _ +________________________________________ +_ Recordings available: _ 12492 _ +________________________________________ +_ Features available: _ 12492 _ +________________________________________ +_ Supervisions available: _ 12492 _ +________________________________________ + +""" diff --git a/egs/mdcc/ASR/local/prepare_char.py 
b/egs/mdcc/ASR/local/prepare_char.py new file mode 120000 index 000000000..42743b544 --- /dev/null +++ b/egs/mdcc/ASR/local/prepare_char.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_char_lm_training_data.py b/egs/mdcc/ASR/local/prepare_char_lm_training_data.py new file mode 120000 index 000000000..2374cafdd --- /dev/null +++ b/egs/mdcc/ASR/local/prepare_char_lm_training_data.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char_lm_training_data.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_lang.py b/egs/mdcc/ASR/local/prepare_lang.py new file mode 120000 index 000000000..bee8d5f03 --- /dev/null +++ b/egs/mdcc/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_lang_fst.py b/egs/mdcc/ASR/local/prepare_lang_fst.py new file mode 120000 index 000000000..c5787c534 --- /dev/null +++ b/egs/mdcc/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/preprocess_mdcc.py b/egs/mdcc/ASR/local/preprocess_mdcc.py new file mode 100755 index 000000000..cd0dc7de8 --- /dev/null +++ b/egs/mdcc/ASR/local/preprocess_mdcc.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a text file "data/lang_char/text" as input, the file consist of +lines each containing a transcript, applies text norm and generates the following +files in the directory "data/lang_char": + - text_norm + - words.txt + - words_no_ids.txt + - text_words_segmentation +""" + +import argparse +import logging +from pathlib import Path +from typing import List + +import pycantonese +from tqdm.auto import tqdm + +from icefall.utils import is_cjk + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Prepare char lexicon", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input-file", + "-i", + default="data/lang_char/text", + type=str, + help="The input text file", + ) + parser.add_argument( + "--output-dir", + "-o", + default="data/lang_char", + type=str, + help="The output directory", + ) + return parser + + +def get_norm_lines(lines: List[str]) -> List[str]: + def _text_norm(text: str) -> str: + # to cope with the protocol for transcription: + # When taking notes, the annotators adhere to the following guidelines: + # 1) If the audio contains pure music, the annotators mark the label + # "(music)" in the file name of its transcript. 2) If the utterance + # contains one or several sentences with background music or noise, the + # annotators mark the label "(music)" before each sentence in the transcript. 
+ # 3) The annotators use {} symbols to enclose words they are uncertain + # about, for example, {梁佳佳},我是{}人. + + # here we manually fix some errors in the transcript + + return ( + text.strip() + .replace("(music)", "") + .replace("(music", "") + .replace("{", "") + .replace("}", "") + .replace("BB所以就指腹為親喇", "BB 所以就指腹為親喇") + .upper() + ) + + return [_text_norm(line) for line in lines] + + +def get_word_segments(lines: List[str]) -> List[str]: + # the current pycantonese segmenter does not handle the case when the input + # is code switching, so we need to handle it separately + + new_lines = [] + + for line in tqdm(lines, desc="Segmenting lines"): + try: + # code switching + if len(line.strip().split(" ")) > 1: + segments = [] + for segment in line.strip().split(" "): + if segment.strip() == "": + continue + try: + if not is_cjk(segment[0]): # en segment + segments.append(segment) + else: # zh segment + segments.extend(pycantonese.segment(segment)) + except Exception as e: + logging.error(f"Failed to process segment: {segment}") + raise e + new_lines.append(" ".join(segments) + "\n") + # not code switching + else: + new_lines.append(" ".join(pycantonese.segment(line)) + "\n") + except Exception as e: + logging.error(f"Failed to process line: {line}") + raise e + return new_lines + + +def get_words(lines: List[str]) -> List[str]: + words = set() + for line in tqdm(lines, desc="Getting words"): + words.update(line.strip().split(" ")) + return list(words) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + + input_file = Path(args.input_file) + output_dir = Path(args.output_dir) + + assert output_dir.is_dir(), f"{output_dir} does not exist" + assert input_file.is_file(), f"{input_file} does not exist" + + lines = input_file.read_text(encoding="utf-8").strip().split("\n") + + norm_lines = get_norm_lines(lines) + with open(output_dir / "text_norm", "w+", encoding="utf-8") as f: + f.writelines([line + "\n" for line in norm_lines]) + + text_words_segments = get_word_segments(norm_lines) + with open(output_dir / "text_words_segmentation", "w+", encoding="utf-8") as f: + f.writelines(text_words_segments) + + words = get_words(text_words_segments)[1:] # remove "\n" from words + with open(output_dir / "words_no_ids.txt", "w+", encoding="utf-8") as f: + f.writelines([word + "\n" for word in sorted(words)]) + + words = ( + ["", "!SIL", "", ""] + + sorted(words) + + ["#0", "", "<\s>"] + ) + + with open(output_dir / "words.txt", "w+", encoding="utf-8") as f: + f.writelines([f"{word} {i}\n" for i, word in enumerate(words)]) diff --git a/egs/mdcc/ASR/local/text2segments.py b/egs/mdcc/ASR/local/text2segments.py new file mode 100755 index 000000000..8ce7ab7e5 --- /dev/null +++ b/egs/mdcc/ASR/local/text2segments.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) +# 2022 Xiaomi Corp. (authors: Weiji Zhuang) +# 2024 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input "text", which refers to the transcript file for +MDCC: + - text +and generates the output file text_word_segmentation which is implemented +with word segmenting: + - text_words_segmentation +""" + +import argparse +from typing import List + +import pycantonese +from tqdm.auto import tqdm + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Cantonese Word Segmentation for text", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input-file", + "-i", + default="data/lang_char/text", + type=str, + help="the input text file for MDCC", + ) + parser.add_argument( + "--output-file", + "-o", + default="data/lang_char/text_words_segmentation", + type=str, + help="the text implemented with words segmenting for MDCC", + ) + + return parser + + +def get_word_segments(lines: List[str]) -> List[str]: + return [ + " ".join(pycantonese.segment(line)) + "\n" + for line in tqdm(lines, desc="Segmenting lines") + ] + + +def main(): + parser = get_parser() + args = parser.parse_args() + + input_file = args.input_file + output_file = args.output_file + + with open(input_file, "r", encoding="utf-8") as fr: + lines = fr.readlines() + + new_lines = get_word_segments(lines) + + with open(output_file, "w", encoding="utf-8") as fw: + fw.writelines(new_lines) + + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/local/text2token.py b/egs/mdcc/ASR/local/text2token.py new file mode 120000 index 000000000..81e459d69 --- /dev/null +++ b/egs/mdcc/ASR/local/text2token.py @@ -0,0 +1 @@ +../../../aidatatang_200zh/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/mdcc/ASR/prepare.sh b/egs/mdcc/ASR/prepare.sh new file mode 100755 index 000000000..f4d9bc47e --- /dev/null +++ b/egs/mdcc/ASR/prepare.sh @@ -0,0 +1,308 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 +perturb_speed=true + + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/mdcc +# |-- README.md +# |-- audio/ +# |-- clip_info_rthk.csv +# |-- cnt_asr_metadata_full.csv +# |-- cnt_asr_test_metadata.csv +# |-- cnt_asr_train_metadata.csv +# |-- cnt_asr_valid_metadata.csv +# |-- data_statistic.py +# |-- length +# |-- podcast_447_2021.csv +# |-- test.txt +# |-- transcription/ +# `-- words_length +# You can download them from: +# https://drive.google.com/file/d/1epfYMMhXdBKA6nxPgUugb2Uj4DllSxkn/view?usp=drive_link +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. 
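+#
+# Example invocation (illustrative only): the stage controls are parsed by
+# shared/parse_options.sh (sourced above), so to skip the download stage and
+# run only manifest preparation and fbank extraction you could use:
+#
+#   ./prepare.sh --stage 1 --stop-stage 4
+#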
+mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "stage 0: Download data" + + # If you have pre-downloaded it to /path/to/mdcc, + # you can create a symlink + # + # ln -sfv /path/to/mdcc $dl_dir/mdcc + # + # The directory structure is + # mdcc/ + # |-- README.md + # |-- audio/ + # |-- clip_info_rthk.csv + # |-- cnt_asr_metadata_full.csv + # |-- cnt_asr_test_metadata.csv + # |-- cnt_asr_train_metadata.csv + # |-- cnt_asr_valid_metadata.csv + # |-- data_statistic.py + # |-- length + # |-- podcast_447_2021.csv + # |-- test.txt + # |-- transcription/ + # `-- words_length + + if [ ! -d $dl_dir/mdcc/audio ]; then + lhotse download mdcc $dl_dir + + # this will download and unzip dataset.zip to $dl_dir/ + + mv $dl_dir/dataset $dl_dir/mdcc + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/musan + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare MDCC manifest" + # We assume that you have downloaded the MDCC corpus + # to $dl_dir/mdcc + if [ ! -f data/manifests/.mdcc_manifests.done ]; then + log "Might take 40 minutes to traverse the directory." + mkdir -p data/manifests + lhotse prepare mdcc $dl_dir/mdcc data/manifests + touch data/manifests/.mdcc_manifests.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + if [ ! -f data/manifests/.musan_manifests.done ]; then + log "It may take 6 minutes" + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan_manifests.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for MDCC" + if [ ! -f data/fbank/.mdcc.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_mdcc.py --perturb-speed ${perturb_speed} + touch data/fbank/.mdcc.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + if [ ! -f data/fbank/.msuan.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_musan.py + touch data/fbank/.msuan.done + fi +fi + +lang_char_dir=data/lang_char +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare char based lang" + mkdir -p $lang_char_dir + + # Prepare text. + # Note: in Linux, you can install jq with the following command: + # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + # 2. chmod +x ./jq + # 3. cp jq /usr/bin + if [ ! -f $lang_char_dir/text ]; then + gunzip -c data/manifests/mdcc_supervisions_train.jsonl.gz \ + |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ + > $lang_char_dir/train_text + + cat $lang_char_dir/train_text > $lang_char_dir/text + + gunzip -c data/manifests/mdcc_supervisions_test.jsonl.gz \ + |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ + > $lang_char_dir/valid_text + + cat $lang_char_dir/valid_text >> $lang_char_dir/text + + gunzip -c data/manifests/mdcc_supervisions_valid.jsonl.gz \ + |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ + > $lang_char_dir/test_text + + cat $lang_char_dir/test_text >> $lang_char_dir/text + fi + + if [ ! 
-f $lang_char_dir/text_words_segmentation ]; then + ./local/preprocess_mdcc.py --input-file $lang_char_dir/text \ + --output-dir $lang_char_dir + + mv $lang_char_dir/text $lang_char_dir/_text + cp $lang_char_dir/text_words_segmentation $lang_char_dir/text + fi + + if [ ! -f $lang_char_dir/tokens.txt ]; then + ./local/prepare_char.py --lang-dir $lang_char_dir + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare G" + + mkdir -p data/lm + + # Train LM on transcripts + if [ ! -f data/lm/3-gram.unpruned.arpa ]; then + python3 ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text $lang_char_dir/text_words_segmentation \ + -lm data/lm/3-gram.unpruned.arpa + fi + + # We assume you have installed kaldilm, if not, please install + # it using: pip install kaldilm + if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="$lang_char_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt + fi + + if [ ! -f $lang_char_dir/HLG.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_char_dir \ + --ngram-G ./data/lm/G_3_gram_char.fst.txt + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compile LG & HLG" + + ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char + ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Generate LM training data" + + log "Processing char based data" + out_dir=data/lm_training_char + mkdir -p $out_dir $dl_dir/lm + + if [ ! -f $dl_dir/lm/mdcc-train-word.txt ]; then + ./local/text2segments.py --input-file $lang_char_dir/train_text \ + --output-file $dl_dir/lm/mdcc-train-word.txt + fi + + # training words + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/mdcc-train-word.txt \ + --lm-archive $out_dir/lm_data.pt + + # valid words + if [ ! -f $dl_dir/lm/mdcc-valid-word.txt ]; then + ./local/text2segments.py --input-file $lang_char_dir/valid_text \ + --output-file $dl_dir/lm/mdcc-valid-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/mdcc-valid-word.txt \ + --lm-archive $out_dir/lm_data_valid.pt + + # test words + if [ ! -f $dl_dir/lm/mdcc-test-word.txt ]; then + ./local/text2segments.py --input-file $lang_char_dir/test_text \ + --output-file $dl_dir/lm/mdcc-test-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/mdcc-test-word.txt \ + --lm-archive $out_dir/lm_data_test.pt +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Sort LM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of tokens + # in a sentence. 
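+  # The sorted archives written below (sorted_lm_data*.pt) are the inputs
+  # consumed by the RNN LM training stage (stage 12) at the end of this script.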
+ + out_dir=data/lm_training_char + mkdir -p $out_dir + ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/ + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data_valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data_test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Train RNN LM model" + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 1 \ + --num-epochs 20 \ + --use-fp16 0 \ + --embedding-dim 512 \ + --hidden-dim 512 \ + --num-layers 2 \ + --batch-size 400 \ + --exp-dir rnnlm_char/exp \ + --lm-data $out_dir/sorted_lm_data.pt \ + --lm-data-valid $out_dir/sorted_lm_data-valid.pt \ + --vocab-size 4336 \ + --master-port 12345 +fi diff --git a/egs/mdcc/ASR/shared b/egs/mdcc/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/mdcc/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/__init__.py b/egs/mdcc/ASR/zipformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/mdcc/ASR/zipformer/asr_datamodule.py b/egs/mdcc/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..1f49b6520 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,382 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 Xiaomi Corporation (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class MdccAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
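+
+    A rough usage sketch (illustrative only; every name below is defined in
+    this file, see ``add_arguments`` and the dataloader/cut methods):
+
+        parser = argparse.ArgumentParser()
+        MdccAsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+        mdcc = MdccAsrDataModule(args)
+        train_dl = mdcc.train_dataloaders(mdcc.train_cuts())
+        valid_dl = mdcc.valid_dataloaders(mdcc.valid_cuts())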
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders( + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures() + ), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest_lazy( + self.args.manifest_dir / "mdcc_cuts_train.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get valid cuts") + return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_valid.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_test.jsonl.gz") diff --git a/egs/mdcc/ASR/zipformer/beam_search.py b/egs/mdcc/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/mdcc/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/decode.py 
b/egs/mdcc/ASR/zipformer/decode.py new file mode 100755 index 000000000..ce104baf7 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/decode.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Mingshuang Luo, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import MdccAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
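+        # (A hedged note: the LOG_EPS padding added below gives the causal
+        # encoder some extra right context so its final chunk is flushed;
+        # the value 30 is a heuristic rather than a quantity derived from
+        # the model configuration.)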
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + sentence = "".join([lexicon.word_table[i] for i in hyp]) + hyps.append(list(sentence)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + 
key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MdccAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif 
len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + mdcc = MdccAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." 
+ ) + return T > 0 + + valid_cuts = mdcc.valid_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_dl = mdcc.valid_dataloaders(valid_cuts) + + test_cuts = mdcc.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = mdcc.test_dataloaders(test_cuts) + + test_sets = ["valid", "test"] + test_dls = [valid_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/zipformer/decode_stream.py b/egs/mdcc/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/decoder.py b/egs/mdcc/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/encoder_interface.py b/egs/mdcc/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-ctc.py b/egs/mdcc/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 000000000..f9d756352 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 120000 index 000000000..652346001 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-streaming.py b/egs/mdcc/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx.py b/egs/mdcc/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export.py b/egs/mdcc/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/joiner.py b/egs/mdcc/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/model.py b/egs/mdcc/ASR/zipformer/model.py new file mode 120000 index 
000000000..cd7e07d72 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/onnx_check.py b/egs/mdcc/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/onnx_decode.py b/egs/mdcc/ASR/zipformer/onnx_decode.py new file mode 100755 index 000000000..1ed4a9fa1 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/onnx_decode.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. +""" + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import MdccAsrDataModule +from lhotse.cut import Cut +from onnx_pretrained import OnnxModel, greedy_search + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: k2.SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + Mapping ids to tokens. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. 
+ """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [[token_table[h] for h in hyp] for hyp in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: k2.SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + Mapping ids to tokens. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = list(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MdccAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." 
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = k2.SymbolTable.from_file(args.tokens) + assert token_table[0] == "" + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + + mdcc = MdccAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + valid_cuts = mdcc.valid_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_dl = mdcc.valid_dataloaders(valid_cuts) + + test_cuts = mdcc.test_net_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = mdcc.test_dataloaders(test_cuts) + + test_sets = ["valid", "test"] + test_dl = [valid_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/zipformer/optim.py b/egs/mdcc/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/mdcc/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/scaling.py b/egs/mdcc/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/scaling_converter.py b/egs/mdcc/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/mdcc/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/streaming_beam_search.py b/egs/mdcc/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/streaming_decode.py b/egs/mdcc/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..dadb0b55f --- /dev/null +++ b/egs/mdcc/ASR/zipformer/streaming_decode.py @@ -0,0 +1,881 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this 
file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +from asr_datamodule import MdccAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="Path to the lang dir(containing lexicon, tokens, etc.)", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. 
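+
+    Example (shapes are illustrative, inferred from the comments in the body
+    below rather than from a real run): stacking 2 utterances whose per-layer
+    cached_key tensors have shape (left_context_len, 1, key_dim) yields a
+    cached_key of shape (left_context_len, 2, key_dim), while the per-utterance
+    cached_conv1 tensors of shape (1, channels, left_pad) are concatenated
+    along dim 0.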
+ """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
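+
+    Example (illustrative, mirroring the stacking example above): unstacking a
+    batch of 3 utterances returns a list of 3 per-utterance state lists, each
+    of length 6 * num_layers + 2, with the attention caches split back along
+    dim 1 and the convolution caches along dim 0.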
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
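+
+    Args (summarised from the body below; the shapes are assumptions based on
+    the surrounding code, not a formal spec):
+      features: padded fbank features of shape (N, T, C).
+      feature_lens: number of valid feature frames per utterance, shape (N,).
+      states: stacked streaming states; states[-2] is the cached left padding
+        of encoder_embed and states[-1] is processed_lens per utterance.
+      chunk_size: expected number of frames per chunk after encoder_embed
+        subsampling (the body asserts x.size(1) == chunk_size).
+      left_context_len: number of cached left-context frames used to build
+        the key padding mask.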
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + encoder_out=encoder_out, + streams=decode_streams, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + lexicon: + The Lexicon. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
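+        # Rough per-cut flow (descriptive only; the constants come from the
+        # calls below): fbank features are computed with kaldifeat at
+        # 16 kHz / 80 mel bins, 30 tail padding frames are appended via
+        # set_features(..., tail_pad_len=30), and decode_one_chunk() consumes
+        # chunk_size * 2 feature frames per decoding step.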
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + list(decode_streams[i].ground_truth.strip()), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + key = f"greedy_search_{key}" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_{key}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}_{key}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MdccAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, 
iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + mdcc = MdccAsrDataModule(args) + + valid_cuts = mdcc.valid_cuts() + test_cuts = mdcc.test_cuts() + + test_sets = ["valid", "test"] + test_cuts = [valid_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + ) + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/zipformer/subsampling.py b/egs/mdcc/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/mdcc/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py new file mode 100755 index 000000000..2fae66844 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/train.py @@ -0,0 +1,1345 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --exp-dir zipformer/exp \ + --max-duration 350 + +# For mix precision training: + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import MdccAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="""Maximum left-contexts for causal training, measured in frames which will + be converted to a number of chunks. 
If splitting into chunks, + chunk left-context frames will be chosen randomly from this list; else not relevant.""", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). 
We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
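+            # Sketch of the schedule implemented below: while the scale is
+            # below 8.0 it is doubled every 100 batches, and while it is below
+            # 32.0 it is doubled every 400 batches; e.g. a scale that has
+            # dropped to 1.0 is pushed back up to 8.0 after roughly 300 clean
+            # batches, assuming no overflow shrinks it again in between.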
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + mdcc = MdccAsrDataModule(args) + + train_cuts = mdcc.train_cuts() + valid_cuts = mdcc.valid_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = mdcc.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) + + valid_dl = mdcc.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + The compiler to encode texts to ids. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = supervisions["text"] + y = graph_compiler.texts_to_ids(texts) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MdccAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/zipformer/zipformer.py b/egs/mdcc/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e64afd1ee..6bafa6aca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -kaldifst +kaldifst>1.7.0 kaldilm kaldialign num2words @@ -14,4 +14,7 @@ onnxruntime==1.16.3 # style check session: black==22.3.0 isort==5.10.1 -flake8==5.0.4 \ No newline at end of file +flake8==5.0.4 + +# cantonese word segment support +pycantonese==3.4.0 \ No newline at end of file From d406b41cbda16df57f4cbe0ee091032a1df7fce3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 13 Mar 2024 11:01:18 +0800 Subject: [PATCH 125/216] Doc: Add page for installing piper-phonemize (#1547) --- .github/scripts/.gitignore | 1 + .../scripts/generate-piper-phonemize-page.py | 29 +++++++++++++++++++ .github/workflows/build-doc.yml | 3 ++ 3 files changed, 33 insertions(+) create mode 100644 .github/scripts/.gitignore create mode 100755 .github/scripts/generate-piper-phonemize-page.py diff --git a/.github/scripts/.gitignore b/.github/scripts/.gitignore new file mode 100644 index 000000000..672e477d8 --- /dev/null +++ b/.github/scripts/.gitignore @@ -0,0 +1 @@ +piper_phonemize.html diff --git a/.github/scripts/generate-piper-phonemize-page.py b/.github/scripts/generate-piper-phonemize-page.py new file mode 100755 index 000000000..3784d5fa5 --- /dev/null +++ b/.github/scripts/generate-piper-phonemize-page.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + + +def main(): + prefix = ( + "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/" + ) + files = [ + "piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl", + 
"piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + ] + with open("piper_phonemize.html", "w") as f: + for file in files: + url = prefix + file + f.write(f'{file}
\n') + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index d7fe2c964..c622476f2 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -56,11 +56,14 @@ jobs: - name: Build doc shell: bash run: | + .github/scripts/generate-piper-phonemize-page.py cd docs python3 -m pip install -r ./requirements.txt make html touch build/html/.nojekyll + cp -v ../piper_phonemize.html ./build/html/ + - name: Deploy uses: peaceiris/actions-gh-pages@v3 with: From 15bd9a841e347a8881fc6df599fd440ebb118da4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 13 Mar 2024 17:39:01 +0800 Subject: [PATCH 126/216] add CI for ljspeech (#1548) --- .github/scripts/librispeech/ASR/run.sh | 6 +- .github/scripts/ljspeech/TTS/run.sh | 157 ++++++++++++++++++ .github/workflows/ljspeech.yml | 102 ++++++++++++ docs/source/recipes/TTS/ljspeech/vits.rst | 69 ++++++++ .../TTS/local/prepare_tokens_ljspeech.py | 6 +- egs/ljspeech/TTS/prepare.sh | 5 +- .../TTS/vits/monotonic_align/__init__.py | 6 +- egs/ljspeech/TTS/vits/tokenizer.py | 4 +- egs/ljspeech/TTS/vits/tts_datamodule.py | 2 + 9 files changed, 347 insertions(+), 10 deletions(-) create mode 100755 .github/scripts/ljspeech/TTS/run.sh create mode 100644 .github/workflows/ljspeech.yml diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh index 7e9bd8a47..293ed66e5 100755 --- a/.github/scripts/librispeech/ASR/run.sh +++ b/.github/scripts/librispeech/ASR/run.sh @@ -15,9 +15,9 @@ function prepare_data() { # cause OOM error for CI later. mkdir -p download/lm pushd download/lm - wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt - wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt - wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz + wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz + wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt + wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt ls -lh gunzip librispeech-lm-norm.txt.gz diff --git a/.github/scripts/ljspeech/TTS/run.sh b/.github/scripts/ljspeech/TTS/run.sh new file mode 100755 index 000000000..707361782 --- /dev/null +++ b/.github/scripts/ljspeech/TTS/run.sh @@ -0,0 +1,157 @@ +#!/usr/bin/env bash + +set -ex + +python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html +python3 -m pip install espnet_tts_frontend +python3 -m pip install numba + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/ljspeech/TTS + +sed -i.bak s/600/8/g ./prepare.sh +sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh +sed -i.bak s/500/5/g ./prepare.sh +git diff + +function prepare_data() { + # We have created a subset of the data for testing + # + mkdir download + pushd download + wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2 + tar xvf LJSpeech-1.1.tar.bz2 + popd + + ./prepare.sh + tree . +} + +function train() { + pushd ./vits + sed -i.bak s/200/3/g ./train.py + git diff . 
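+  # The sed above only shrinks a hard-coded 200 in vits/train.py to 3 (presumably a
+  # logging/validation interval) so that this CI job finishes quickly; the generated
+  # train.py.bak keeps the original value for reference.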
+ popd + + for t in low medium high; do + ./vits/train.py \ + --exp-dir vits/exp-$t \ + --model-type $t \ + --num-epochs 1 \ + --save-every-n 1 \ + --num-buckets 2 \ + --tokens data/tokens.txt \ + --max-duration 20 + + ls -lh vits/exp-$t + done +} + +function infer() { + for t in low medium high; do + ./vits/infer.py \ + --num-buckets 2 \ + --model-type $t \ + --epoch 1 \ + --exp-dir ./vits/exp-$t \ + --tokens data/tokens.txt \ + --max-duration 20 + done +} + +function export_onnx() { + for t in low medium high; do + ./vits/export-onnx.py \ + --model-type $t \ + --epoch 1 \ + --exp-dir ./vits/exp-$t \ + --tokens data/tokens.txt + + ls -lh vits/exp-$t/ + done +} + +function test_medium() { + git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12 + + ./vits/export-onnx.py \ + --model-type medium \ + --epoch 820 \ + --exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \ + --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt + + ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp + + ./vits/test_onnx.py \ + --model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \ + --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \ + --output-filename /icefall/test-medium.wav + + ls -lh /icefall/test-medium.wav + + d=/icefall/vits-icefall-en_US-ljspeech-medium + mkdir $d + cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/ + cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx + + rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12 + + pushd $d + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 + tar xf espeak-ng-data.tar.bz2 + rm espeak-ng-data.tar.bz2 + cd .. + tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium + rm -rf vits-icefall-en_US-ljspeech-medium + ls -lh *.tar.bz2 + popd +} + +function test_low() { + git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12 + + ./vits/export-onnx.py \ + --model-type low \ + --epoch 1600 \ + --exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \ + --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt + + ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp + + ./vits/test_onnx.py \ + --model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \ + --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \ + --output-filename /icefall/test-low.wav + + ls -lh /icefall/test-low.wav + + d=/icefall/vits-icefall-en_US-ljspeech-low + mkdir $d + cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/ + cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx + + rm -rf icefall-tts-ljspeech-vits-low-2024-03-12 + + pushd $d + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 + tar xf espeak-ng-data.tar.bz2 + rm espeak-ng-data.tar.bz2 + cd .. 
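+  # At this point $d contains model.onnx, tokens.txt and espeak-ng-data/, which is the
+  # layout expected by sherpa-onnx-offline-tts (see docs/source/recipes/TTS/ljspeech/vits.rst).
+  # A hypothetical invocation against this directory, with the output path chosen here
+  # only as an example:
+  #   sherpa-onnx-offline-tts \
+  #     --vits-model=$d/model.onnx \
+  #     --vits-tokens=$d/tokens.txt \
+  #     --vits-data-dir=$d/espeak-ng-data \
+  #     --output-filename=/icefall/test-low-sherpa.wav \
+  #     "Ask not what your country can do for you; ask what you can do for your country."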
+ tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low + rm -rf vits-icefall-en_US-ljspeech-low + ls -lh *.tar.bz2 + popd +} + +prepare_data +train +infer +export_onnx +rm -rf vits/exp-{low,medium,high} +test_medium +test_low diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml new file mode 100644 index 000000000..25402275b --- /dev/null +++ b/.github/workflows/ljspeech.yml @@ -0,0 +1,102 @@ +name: ljspeech + +on: + push: + branches: + - master + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: ljspeech-${{ github.ref }} + cancel-in-progress: true + +jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + ljspeech: + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + ls -lh + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run tests + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall + + .github/scripts/ljspeech/TTS/run.sh + + - name: display files + shell: bash + run: | + ls -lh + + - uses: actions/upload-artifact@v4 + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + with: + name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }} + path: ./*.wav + + - uses: actions/upload-artifact@v4 + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + with: + name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }} + path: ./*.wav + + - name: Release exported onnx models + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + overwrite: true + file: vits-icefall-*.tar.bz2 + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: tts-models + diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index d31bf6302..9499a3aea 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -13,6 +13,14 @@ with the `LJSpeech `_ dataset. The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ +Install extra dependencies +-------------------------- + +.. 
code-block:: bash + + pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html + pip install numba espnet_tts_frontend + Data preparation ---------------- @@ -130,3 +138,64 @@ by visiting the following link: - ``--model-type=medium``: ``_ - ``--model-type=low``: ``_ +Usage in sherpa-onnx +-------------------- + +The following describes how to test the exported ONNX model in `sherpa-onnx`_. + +.. hint:: + + `sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python, + Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS. + + We only describe how to use pre-built binaries from `sherpa-onnx`_ below. + Please refer to ``_ + for more documentation. + +Install sherpa-onnx +^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + pip install sherpa-onnx + +To check that you have installed `sherpa-onnx`_ successfully, please run: + +.. code-block:: bash + + which sherpa-onnx-offline-tts + sherpa-onnx-offline-tts --help + +Download lexicon files +^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + cd /tmp + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 + tar xf espeak-ng-data.tar.bz2 + +Run sherpa-onnx +^^^^^^^^^^^^^^^ + +.. code-block:: bash + + cd egs/ljspeech/TTS + + sherpa-onnx-offline-tts \ + --vits-model=vits/exp/vits-epoch-1000.onnx \ + --vits-tokens=data/tokens.txt \ + --vits-data-dir=/tmp/espeak-ng-data \ + --num-threads=1 \ + --output-filename=./high.wav \ + "Ask not what your country can do for you; ask what you can do for your country." + +.. hint:: + + You can also use ``sherpa-onnx-offline-tts-play`` to play the audio + as it is generating. + +You should get a file ``high.wav`` after running the above command. + +Congratulations! You have successfully trained and exported a text-to-speech +model and run it with `sherpa-onnx`_. diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py index 08fe7430e..4ba88604c 100755 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -23,7 +23,11 @@ This file reads the texts in given manifest and save the new cuts with phoneme t import logging from pathlib import Path -import tacotron_cleaner.cleaners +try: + import tacotron_cleaner.cleaners +except ModuleNotFoundError as ex: + raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n") + from lhotse import CutSet, load_manifest from piper_phonemize import phonemize_espeak diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index bded423ac..9ed0f93fd 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -28,7 +28,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then log "Stage -1: build monotonic_align lib" if [ ! -d vits/monotonic_align/build ]; then cd vits/monotonic_align - python setup.py build_ext --inplace + python3 setup.py build_ext --inplace cd ../../ else log "monotonic_align lib already built" @@ -82,8 +82,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare phoneme tokens for LJSpeech" # We assume you have installed piper_phonemize and espnet_tts_frontend. 
# If not, please install them with: - # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, - # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then ./local/prepare_tokens_ljspeech.py diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py index 2b35654f5..5dc3641e5 100644 --- a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py +++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py @@ -10,7 +10,11 @@ import warnings import numpy as np import torch -from numba import njit, prange + +try: + from numba import njit, prange +except ModuleNotFoundError as ex: + raise RuntimeError(f"{ex}/nPlease run\n pip install numba") try: from .core import maximum_path_c diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 8144ffe1e..3c9046add 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -23,8 +23,8 @@ try: from piper_phonemize import phonemize_espeak except Exception as ex: raise RuntimeError( - f"{ex}\nPlease follow instructions in " - "../prepare.sh to install piper-phonemize" + f"{ex}\nPlease run\n" + "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" ) from utils import intersperse diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index 8ff868bc8..e1a9c7b3c 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -255,6 +255,7 @@ class LJSpeechTtsDataModule: valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, shuffle=False, ) logging.info("About to create valid dataloader") @@ -294,6 +295,7 @@ class LJSpeechTtsDataModule: test_sampler = DynamicBucketingSampler( cuts, max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, shuffle=False, ) logging.info("About to create test dataloader") From eb132da00d9666082a65692a542c9c94db2a56b4 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 14 Mar 2024 11:33:49 +0800 Subject: [PATCH 127/216] additional instruction for the `grad_scale is too small` error (#1550) --- .../do_not_use_it_directly.py | 5 +- .../ASR/pruned_transducer_stateless7/train.py | 5 +- .../train.py | 5 +- .../do_not_use_it_directly.py | 5 +- .../train.py | 5 +- egs/aishell/ASR/zipformer/train.py | 5 +- egs/aishell/ASR/zipformer/train_bbpe.py | 5 +- .../pruned_transducer_stateless7/train.py | 5 +- .../ASR/pruned_transducer_stateless7/train.py | 5 +- egs/ami/SURT/dprnn_zipformer/train.py | 5 +- egs/ami/SURT/dprnn_zipformer/train_adapt.py | 5 +- .../ASR/pruned_transducer_stateless7/train.py | 5 +- .../do_not_use_it_directly.py | 4 +- .../finetune.py | 5 +- .../train.py | 5 +- .../do_not_use_it_directly.py | 5 +- .../train.py | 5 +- egs/gigaspeech/ASR/zipformer/train.py | 5 +- egs/gigaspeech/KWS/zipformer/finetune.py | 5 +- egs/gigaspeech/KWS/zipformer/train.py | 5 +- egs/libricss/SURT/dprnn_zipformer/train.py | 5 +- .../SURT/dprnn_zipformer/train_adapt.py | 5 +- egs/libriheavy/ASR/zipformer/train.py | 5 +- .../zipformer_prompt_asr/train_baseline.py | 5 +- .../train_bert_encoder.py | 11 
++--- .../pruned_transducer_stateless7/finetune.py | 5 +- .../ASR/pruned_transducer_stateless7/train.py | 5 +- .../pruned_transducer_stateless7_ctc/train.py | 5 +- .../train.py | 5 +- .../do_not_use_it_directly.py | 5 +- .../train.py | 5 +- .../train.py | 5 +- .../ASR/pruned_transducer_stateless8/train.py | 5 +- .../ASR/tiny_transducer_ctc/train.py | 5 +- egs/librispeech/ASR/zipformer/finetune.py | 5 +- egs/librispeech/ASR/zipformer/train.py | 5 +- .../ASR/zipformer_adapter/train.py | 5 +- egs/librispeech/ASR/zipformer_ctc/train.py | 5 +- egs/librispeech/ASR/zipformer_mmi/train.py | 5 +- egs/multi_zh-hans/ASR/zipformer/train.py | 5 +- egs/multi_zh_en/ASR/zipformer/train.py | 5 +- egs/spgispeech/ASR/zipformer/train.py | 5 +- .../train.py | 5 +- egs/tedlium3/ASR/zipformer/train.py | 5 +- egs/wenetspeech/ASR/zipformer/train.py | 5 +- egs/wenetspeech/KWS/zipformer/finetune.py | 5 +- egs/wenetspeech/KWS/zipformer/train.py | 5 +- .../ASR/pruned_transducer_stateless7/train.py | 5 +- icefall/err.py | 47 +++++++++++++++++++ 49 files changed, 145 insertions(+), 147 deletions(-) create mode 100644 icefall/err.py diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py index 6027273b2..058d0ff6b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py @@ -89,6 +89,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -881,9 +882,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error() if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index 9d9dd4288..2dc835f3b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -85,6 +85,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import ( @@ -878,9 +879,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index 3858bafd7..811269989 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -78,6 +78,7 @@ from icefall.checkpoint import ( ) from 
icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -871,9 +872,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 0fba3b58f..6653d9d9c 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -78,6 +78,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -882,9 +883,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py index 2e1044658..f3b0f1e11 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py @@ -78,6 +78,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -881,9 +882,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index d381649e4..a25979226 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -86,6 +86,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import ( @@ -985,9 +986,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + 
raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py index a2bf96b29..0713c5787 100755 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -83,6 +83,7 @@ from icefall.checkpoint import ( update_averaged_model, ) from icefall.dist import cleanup_dist, setup_dist +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -570,9 +571,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index 8f09f1aa5..30879d8d2 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -70,6 +70,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -851,9 +852,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index 9b67141c0..d62cdadb7 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -69,6 +69,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -842,9 +843,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py index cd5fafc34..adc6a8495 100755 --- a/egs/ami/SURT/dprnn_zipformer/train.py +++ b/egs/ami/SURT/dprnn_zipformer/train.py @@ -75,6 +75,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, 
optim.LRScheduler] @@ -1138,9 +1139,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py index 9f3b4425f..ac5b0dadc 100755 --- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py @@ -75,6 +75,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1129,9 +1130,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 4aedeffe4..4957c0c31 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -79,6 +79,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -871,9 +872,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 0426bc9a3..8e16cc0bf 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -889,9 +889,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise RuntimeError(f", exiting: {cur_grad_scale}") if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py index 3a10c5d81..47c4ed312 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -81,6 +81,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import 
get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -965,9 +966,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py index a9bc9c2a2..dc5d0a858 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -78,6 +78,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -888,9 +889,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 685f6ece6..6d256308c 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -81,6 +81,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -909,9 +910,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index 73fcd67aa..ef7ea9013 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -81,6 +81,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -908,9 +909,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + 
raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index c5335562c..f0ad98147 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -89,6 +89,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1031,9 +1032,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py index 2cd7c868b..a7ba56127 100755 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -100,6 +100,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -371,9 +372,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index e7387dd39..a4d670169 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -89,6 +89,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1034,9 +1035,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py index 6598f8b5d..90d742e7c 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train.py +++ b/egs/libricss/SURT/dprnn_zipformer/train.py @@ -85,6 +85,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1169,9 +1170,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, 
exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py index 1c1b0c28c..8c37430ec 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py @@ -81,6 +81,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1056,9 +1057,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py index c97da4a11..8d4d9d067 100644 --- a/egs/libriheavy/ASR/zipformer/train.py +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -93,6 +93,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1036,9 +1037,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index c8b20d021..93f7e1248 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -103,6 +103,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1051,9 +1052,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py index 9822b99c1..2a2c206aa 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -117,6 +117,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils 
import ( AttributeDict, @@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, - context_dim=4 * 768 - if params.context_injection - else -1, # the output dim of text encoder + context_dim=( + 4 * 768 if params.context_injection else -1 + ), # the output dim of text encoder context_injection=params.context_injection, ) return joiner @@ -1398,9 +1399,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index a7a8ef149..e7546ec45 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -80,6 +80,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -976,9 +977,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index fac3706d2..436ec53b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -81,6 +81,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -878,9 +879,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index d8fa08372..b35e56abc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -81,6 +81,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -902,9 +903,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: 
{cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 25a1aa674..c2d877a93 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -77,6 +77,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -891,9 +892,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 9a6d2155b..8e239e322 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -80,6 +80,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -880,9 +881,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index e1bdce49d..8bd00bbef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -80,6 +80,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -879,9 +880,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index 1642ef4b7..da5e144c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -84,6 +84,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -946,9 +947,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 3f271c5b4..646f30ca1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -89,6 +89,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -946,9 +947,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 8920764cd..1bfd071de 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -66,6 +66,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import UniqLexicon from icefall.utils import ( @@ -883,9 +884,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 843d103cc..2f7ec0c17 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -92,6 +92,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1122,9 +1123,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval 
== 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3ccf7d2f1..1111d32ab 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -90,6 +90,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1021,9 +1022,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index e64c10e7a..6c55896a8 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -81,6 +81,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1125,9 +1126,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index 60990456d..60112a84e 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -62,6 +62,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -797,9 +798,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index dd8949523..c1785a328 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -79,6 +79,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon, UniqLexicon from icefall.mmi import LFMMILoss @@ -816,9 +817,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is 
too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index c1bbd2ee8..447ca122f 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -89,6 +89,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1020,9 +1021,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py index 310c8fe59..5dba584f7 100755 --- a/egs/multi_zh_en/ASR/zipformer/train.py +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -89,6 +89,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1042,9 +1043,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py index 1709a2845..ed66ca29b 100755 --- a/egs/spgispeech/ASR/zipformer/train.py +++ b/egs/spgispeech/ASR/zipformer/train.py @@ -89,6 +89,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -1020,9 +1021,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py index aee3972cd..2108266ec 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -78,6 +78,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -870,9 +871,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - 
raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 5ad01df27..14a44efb3 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -87,6 +87,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, @@ -985,9 +986,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index b1557dedb..3d3762916 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -86,6 +86,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import ( @@ -985,9 +986,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index 76df7e8d5..3ad16fd11 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -111,6 +111,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import ( @@ -525,9 +526,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 05acbd6a9..b5cf3359a 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -88,6 +88,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import ( @@ -1042,9 +1043,7 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: save_bad_model() - 
raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index 8c53972fd..d24c27326 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -81,6 +81,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -854,9 +855,7 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) + raise_grad_scale_is_too_small_error(cur_grad_scale) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] diff --git a/icefall/err.py b/icefall/err.py new file mode 100644 index 000000000..065e2a53d --- /dev/null +++ b/icefall/err.py @@ -0,0 +1,47 @@ +# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin,) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def raise_grad_scale_is_too_small_error(cur_grad_scale: float): + raise RuntimeError( + f""" + grad_scale is too small, exiting: {cur_grad_scale} + + ========================= NOTE ========================= + If you see this error, it means that the gradient scale is too small. + + The default base_lr is 0.045 / 0.05 (depending on which recipe you are + using); this is an empirical value obtained mostly using 4 * 32GB V100 + GPUs with a max_duration of approx. 1,000. + The proper value of base_lr may vary depending on the number of GPUs + and the value of max-duration you are using. + + To fix this issue, you may need to adjust the value of base_lr accordingly. + + We suggest that you decrease the value of base_lr by 0.005 (e.g., + from 0.045 to 0.04), and try again. If the error still exists, you may + repeat the process until base_lr hits 0.02. (Note that this will lead to + a certain loss of performance, but it should work. You can compensate for this by + increasing the num_epochs.) + + If the error still exists, you could seek help by raising an issue, + with a detailed description of (a) your computational resources, (b) the + base_lr and (c) the max_duration you are using, and (d) the detailed configuration + of your model.
+ + ======================================================== + """ + ) From f28c05f4f51b271aa68811fb12358a301ff4ba0c Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 14 Mar 2024 12:18:49 +0800 Subject: [PATCH 128/216] Documentation for adapter fine-tuning (#1545) --- .../Finetune/adapter/finetune_adapter.rst | 225 ++++++++++++++++++ docs/source/recipes/Finetune/index.rst | 1 + 2 files changed, 226 insertions(+) create mode 100644 docs/source/recipes/Finetune/adapter/finetune_adapter.rst diff --git a/docs/source/recipes/Finetune/adapter/finetune_adapter.rst b/docs/source/recipes/Finetune/adapter/finetune_adapter.rst new file mode 100644 index 000000000..a94b008f6 --- /dev/null +++ b/docs/source/recipes/Finetune/adapter/finetune_adapter.rst @@ -0,0 +1,225 @@ +Finetune from a pre-trained Zipformer model with adapters +========================================================= + +This tutorial shows you how to fine-tune a pre-trained **Zipformer** +transducer model on a new dataset with adapters. +Adapters are compact and efficient modules that can be integrated into a pre-trained model +to improve the model's performance on a new domain. Adapters are injected +between different modules in the well-trained neural network. During training, only the parameters +in the adapters will be updated. This approach achieves competitive performance +while requiring much less GPU memory than full fine-tuning. For more details about adapters, +please refer to the original `paper `_. + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have set up + the environment for ``icefall``. + +.. HINT:: + + We recommend using one or more GPUs to run this recipe. + +For illustration purposes, we fine-tune the Zipformer transducer model +pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your +own data for fine-tuning if you create a manifest for your new dataset. + +Data preparation +---------------- + +Please follow the instructions in the `GigaSpeech recipe `_ +to prepare the fine-tuning data used in this tutorial. Only the small subset of GigaSpeech is required. + + +Model preparation +----------------- + +We are using the Zipformer model trained on full LibriSpeech (960 hours) as the initialization. The +checkpoint of the model can be downloaded via the following command: + +.. code-block:: bash + + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp + $ git lfs pull --include "pretrained.pt" + $ ln -s pretrained.pt epoch-99.pt + $ cd ../data/lang_bpe_500 + $ git lfs pull --include bpe.model + $ cd ../../.. + +Before fine-tuning, let's test the model's WER on the new domain. The following command performs +decoding on the GigaSpeech test sets: + +.. code-block:: bash + + ./zipformer/decode_gigaspeech.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \ + --use-averaged-model 0 \ + --max-duration 1000 \ + --decoding-method greedy_search + +You should see the following numbers: + +.. code-block:: + + For dev, WER of different settings are: + greedy_search 20.06 best for dev + + For test, WER of different settings are: + greedy_search 19.27 best for test + + +Fine-tune with adapter +---------------------- + +We insert 4 adapters with residual connections in each ``Zipformer2EncoderLayer``.
+The original model parameters remain untouched during training and only the parameters of +the adapters are updated. The following command starts a fine-tuning experiment with adapters: + +.. code-block:: bash + + $ do_finetune=1 + $ use_adapters=1 + $ adapter_dim=8 + + $ ./zipformer_adapter/train.py \ + --world-size 2 \ + --num-epochs 20 \ + --start-epoch 1 \ + --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \ + --use-fp16 1 \ + --base-lr 0.045 \ + --use-adapters $use_adapters --adapter-dim $adapter_dim \ + --bpe-model data/lang_bpe_500/bpe.model \ + --do-finetune $do_finetune \ + --master-port 13022 \ + --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \ + --max-duration 1000 + +The following arguments are related to fine-tuning: + +- ``--do-finetune`` + If True, do fine-tuning by initializing the model from a pre-trained checkpoint. + **Note that if you want to resume your fine-tuning experiment from a certain epoch, you + need to set this to False.** + +- ``--use-adapters`` + Whether adapters are used during fine-tuning. + +- ``--adapter-dim`` + The bottleneck dimension of the adapter module. Typically a small number. + +You should notice that in the training log, the total number of trainable parameters is shown: + +.. code-block:: + + 2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model) + +The trainable parameters make up only 1.15% of the model's parameters, so training will be much faster +and require less memory than full fine-tuning. + + +Decoding +-------- + +After training, let's test the WERs on the GigaSpeech test sets. You can execute +the following command: + +.. code-block:: bash + + $ epoch=20 + $ avg=10 + $ use_adapters=1 + $ adapter_dim=8 + + $ ./zipformer/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \ + --max-duration 600 \ + --use-adapters $use_adapters \ + --adapter-dim $adapter_dim \ + --decoding-method greedy_search + +You should see the following numbers: + +.. code-block:: + + For dev, WER of different settings are: + greedy_search 15.44 best for dev + + For test, WER of different settings are: + greedy_search 15.42 best for test + + +The WER on the test set improves from 19.27 to 15.42, demonstrating the effectiveness of adapters. + +The same model can be used to perform decoding on the LibriSpeech test sets. You can deactivate the adapters +to retain the original model's performance: + +.. code-block:: bash + + $ epoch=20 + $ avg=1 + $ use_adapters=0 + $ adapter_dim=8 + + $ ./zipformer/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \ + --max-duration 600 \ + --use-adapters $use_adapters \ + --adapter-dim $adapter_dim \ + --decoding-method greedy_search + + +.. code-block:: + + For dev, WER of different settings are: + greedy_search 2.23 best for test-clean + + For test, WER of different settings are: + greedy_search 4.96 best for test-other + +The numbers are the same as reported in `icefall `_. So adapter-based +fine-tuning is also very flexible, as the same model can be used for decoding on both the original and the target domain.
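+
+The adapter itself is just a small bottleneck network wrapped in a residual connection,
+so ``--adapter-dim`` is the width of that bottleneck. The snippet below is a simplified,
+self-contained sketch of this idea; the class ``Adapter`` and the helper
+``mark_only_adapters_as_trainable`` are illustrative only and are not the exact code used
+in ``zipformer_adapter/zipformer.py``:
+
+.. code-block:: python
+
+   import torch
+   import torch.nn as nn
+
+   class Adapter(nn.Module):
+       """A bottleneck adapter: down-project, non-linearity, up-project, residual."""
+
+       def __init__(self, embed_dim: int, adapter_dim: int = 8):
+           super().__init__()
+           self.down = nn.Linear(embed_dim, adapter_dim)  # embed_dim -> adapter_dim
+           self.activation = nn.ReLU()
+           self.up = nn.Linear(adapter_dim, embed_dim)  # adapter_dim -> embed_dim
+
+       def forward(self, x: torch.Tensor) -> torch.Tensor:
+           # The residual connection means the adapter only learns a small
+           # correction on top of the frozen pre-trained representation.
+           return x + self.up(self.activation(self.down(x)))
+
+   def mark_only_adapters_as_trainable(model: nn.Module) -> None:
+       # Freeze everything except parameters whose name contains "adapter".
+       for name, param in model.named_parameters():
+           param.requires_grad = "adapter" in name
+
+Since only the small down- and up-projections are trained, the number of trainable
+parameters grows roughly as ``2 * embed_dim * adapter_dim`` per adapter, which is why the
+fraction of trainable parameters reported in the training log above is so small.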
+ + +Export the model +---------------- + +After training, the model can be exported to ``onnx`` format easily using the following command: + +.. code-block:: bash + + $ use_adapters=1 + $ adapter_dim=16 + + $ ./zipformer_adapter/export-onnx.py \ + --tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 1 \ + --epoch 20 \ + --avg 10 \ + --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \ + --use-adapters $use_adapters \ + --adapter-dim $adapter_dim \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" \ No newline at end of file diff --git a/docs/source/recipes/Finetune/index.rst b/docs/source/recipes/Finetune/index.rst index e62b8980f..7f36d2687 100644 --- a/docs/source/recipes/Finetune/index.rst +++ b/docs/source/recipes/Finetune/index.rst @@ -13,3 +13,4 @@ data to improve the performance on new domains. :caption: Table of Contents from_supervised/finetune_zipformer + adapter/finetune_adapter From 2dfd5dbf8be72fce434445268e415bcdfa4b3028 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:19:23 +0800 Subject: [PATCH 129/216] Add LoRA for Zipformer (#1540) --- egs/librispeech/ASR/README.md | 1 + .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 10 +- .../ASR/zipformer_lora/asr_datamodule.py | 1 + .../ASR/zipformer_lora/beam_search.py | 1 + .../ASR/zipformer_lora/decode_gigaspeech.py | 1115 ++++++++ egs/librispeech/ASR/zipformer_lora/decoder.py | 1 + .../ASR/zipformer_lora/encoder_interface.py | 1 + egs/librispeech/ASR/zipformer_lora/export.py | 543 ++++ .../ASR/zipformer_lora/finetune.py | 1553 ++++++++++ egs/librispeech/ASR/zipformer_lora/joiner.py | 1 + egs/librispeech/ASR/zipformer_lora/model.py | 1 + egs/librispeech/ASR/zipformer_lora/optim.py | 1 + egs/librispeech/ASR/zipformer_lora/scaling.py | 2052 ++++++++++++++ .../ASR/zipformer_lora/scaling_converter.py | 1 + .../ASR/zipformer_lora/subsampling.py | 1 + egs/librispeech/ASR/zipformer_lora/train.py | 1398 +++++++++ .../ASR/zipformer_lora/zipformer.py | 2522 +++++++++++++++++ 17 files changed, 9196 insertions(+), 7 deletions(-) create mode 120000 egs/librispeech/ASR/zipformer_lora/asr_datamodule.py create mode 120000 egs/librispeech/ASR/zipformer_lora/beam_search.py create mode 100755 egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py create mode 120000 egs/librispeech/ASR/zipformer_lora/decoder.py create mode 120000 egs/librispeech/ASR/zipformer_lora/encoder_interface.py create mode 100755 egs/librispeech/ASR/zipformer_lora/export.py create mode 100755 egs/librispeech/ASR/zipformer_lora/finetune.py create mode 120000 egs/librispeech/ASR/zipformer_lora/joiner.py create mode 120000 egs/librispeech/ASR/zipformer_lora/model.py create mode 120000 egs/librispeech/ASR/zipformer_lora/optim.py create mode 100644 egs/librispeech/ASR/zipformer_lora/scaling.py create mode 120000 egs/librispeech/ASR/zipformer_lora/scaling_converter.py create mode 120000 egs/librispeech/ASR/zipformer_lora/subsampling.py create mode 100755 
egs/librispeech/ASR/zipformer_lora/train.py create mode 100644 egs/librispeech/ASR/zipformer_lora/zipformer.py diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 5c5a76917..080f81c91 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -36,6 +36,7 @@ The following table lists the differences among them. | `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty | | `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe | | `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters | +| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune Zipformer with LoRA | The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 2ab051e83..814390ad6 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -479,18 +479,14 @@ class LibriSpeechAsrDataModule: @lru_cache() def gigaspeech_subset_small_cuts(self) -> CutSet: logging.info("About to get Gigaspeech subset-S cuts") - return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") @lru_cache() def gigaspeech_dev_cuts(self) -> CutSet: logging.info("About to get Gigaspeech dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") @lru_cache() def gigaspeech_test_cuts(self) -> CutSet: logging.info("About to get Gigaspeech test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py b/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/beam_search.py b/egs/librispeech/ASR/zipformer_lora/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py new file mode 100755 index 000000000..4d93a905f --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py @@ -0,0 +1,1115 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["", ""] +gigaspeech_punctuations = [ + "", + "", + "", + "", +] +gigaspeech_garbage_utterance_tags = ["", "", "", ""] +non_scoring_words = ( + conversational_filler + + unk_tags + + 
gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. 
The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and 
params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: 
torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts() + + dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts) + test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_lora/decoder.py b/egs/librispeech/ASR/zipformer_lora/decoder.py new file mode 120000 index 000000000..cab465d2b --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/decoder.py @@ -0,0 +1 @@ +../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/encoder_interface.py b/egs/librispeech/ASR/zipformer_lora/encoder_interface.py new file mode 120000 index 000000000..aa5d0217a --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/export.py b/egs/librispeech/ASR/zipformer_lora/export.py new file mode 100755 index 000000000..d47666bef --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/export.py @@ -0,0 +1,543 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer_lora/export.py \ + --exp-dir ./zipformer_lora/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer_lora/export.py \ + --exp-dir ./zipformer_lora/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. 
+You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer_lora/export.py \ + --exp-dir ./zipformer_lora/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer_lora/export.py \ + --exp-dir ./zipformer_lora/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer_lora/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer_lora/decode.py \ + --exp-dir ./zipformer_lora/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zipformer_lora/decode.py` and `zipformer_lora/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer_lora/decode.py \ + --exp-dir ./zipformer_lora/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zipformer_lora/streaming_decode.py \ + --exp-dir ./zipformer_lora/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +- non-streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + +- streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_lora/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. 
+ """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
+ """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + # merge the LoRA weights + model.eval() + + params.use_lora = False + base_model = get_model(params) + + new_state_dict = {} + state_dict = model.state_dict() + param_names = base_model.state_dict().keys() + for k in param_names: + assert k in state_dict.keys() + new_state_dict[k] = state_dict[k] + + base_model.load_state_dict(new_state_dict, strict=True) + + model = base_model + model.eval() + + if 
params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py new file mode 100755 index 000000000..0464cf65c --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -0,0 +1,1553 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Fine-tune without mux (i.e not mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 0 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +# Fine-tune without mux (i.e mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 1 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + 100000 + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. 
This is useful + if you want to maintain the performance on the original domain + """, + ) + + parser.add_argument( + "--use-lora", type=str2bool, default=True, help="If use LoRA for fine-tune" + ) + + parser.add_argument( + "--lora-r", type=int, default=0, help="The bottleneck dimension of LoRA" + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
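+    # Editor's note, an illustrative check of the arithmetic described above
+    # (assuming the usual 10 ms frame shift): a 10-second utterance gives
+    # T = 1000 feature frames, so encoder_embed outputs (1000 - 7) // 2 = 496
+    # frames (~50 Hz), and the extra factor-2 downsampling at the encoder
+    # output leaves roughly 248 frames (~25 Hz).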
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + use_lora=params.use_lora, + lora_r=params.lora_r if params.use_lora else 0, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + for name, m in model.named_modules(): + if "lora" in name: + m.training = True + else: + m.training = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Computing validation loss on {valid_set}") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + logging.info( + f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train + ) + model.train() + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by 
`mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + else: + # resuming training + assert params.start_epoch > 1, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + # keep the original model untouched, only update the adapters + num_trainable = 0 + for name, p in model.named_parameters(): + if "lora_A" in name or "lora_B" in name: + p.requires_grad = True + num_trainable += p.numel() + else: + p.requires_grad = False + + logging.info( + "A total of {} trainable parameters ({:.3f}% of the whole model)".format( + num_trainable, num_trainable / num_param * 100 + ) + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_cuts = 
librispeech.gigaspeech_subset_small_cuts() + if params.use_mux: + librispeech_cuts = librispeech.train_all_shuf_cuts() + train_cuts = CutSet.mux( + gigaspeech_cuts, # num cuts = 688182 + librispeech_cuts, # num cuts = 843723 + weights=[688182, 843723], + stop_early=True, + ) + else: + train_cuts = gigaspeech_cuts + logging.info(train_cuts) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + + valid_sets = ["librispeech", "gigaspeech"] + valid_dls = [ + librispeech.valid_dataloaders(valid_cuts), + librispeech.valid_dataloaders(gigaspeech_dev_cuts), + ] + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) 
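+        # Editor's note: each epoch ends with exp_dir/epoch-{cur_epoch}.pt
+        # being written (plus best-train-loss.pt / best-valid-loss.pt copies
+        # when a new best is reached). Per the branch in run() above, an
+        # interrupted fine-tuning job would be resumed with --do-finetune 0
+        # and --start-epoch set to the next epoch number.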
+ + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_lora/joiner.py b/egs/librispeech/ASR/zipformer_lora/joiner.py new file mode 120000 index 000000000..444cb5f15 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/joiner.py @@ -0,0 +1 @@ +../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/model.py b/egs/librispeech/ASR/zipformer_lora/model.py new file mode 120000 index 000000000..0c6fe6112 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/model.py @@ -0,0 +1 @@ +../zipformer/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/optim.py b/egs/librispeech/ASR/zipformer_lora/optim.py new file mode 120000 index 000000000..207eecfcd --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/optim.py @@ -0,0 +1 @@ +../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py new file mode 100644 index 000000000..3149db9f3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -0,0 +1,2052 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import math +import random +from typing import Optional, Tuple, Union + +import k2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: Tensor, y: Tensor) -> Tensor: + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. 
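+    # Editor's note: logaddexp_onnx() above relies on the identity
+    #   log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|)),
+    # which is numerically stable because the exponent is never positive.
+    # Quick sanity check: logaddexp(0, 0) = log(2) ~= 0.6931, and for
+    # x = 100, y = 0 the result is ~100.0 instead of overflowing.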
+ # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). + return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + + def __init__(self, *args): + assert len(args) >= 1, len(args) + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [(float(x), float(y)) for x, y in args] + for x, y in self.pairs: + assert isinstance(x, (float, int)), type(x) + assert isinstance(y, (float, int)), type(y) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) + + def __str__(self): + # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if x >= cur_x and x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, (float, int)): + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear( + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def max(self, x): + if isinstance(x, (float, int)): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise linear + functions to self and p, but with the same x values. + + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p crosss. + """ + assert isinstance(p, PiecewiseLinear), type(p) + + # get sorted x-values without repetition. + x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): + # if the two lines in this subsegment potentially cross each other.. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. 
+ pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specify the (x,y) pairs + in sorted order on x; x corresponds to the batch index. For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or not in training mode or in + torch.jit scripting mode. + """ + + def __init__(self, *args, default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. + self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) + + def __float__(self): + batch_count = self.batch_count + if ( + batch_count is None + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, default=self.default) + else: + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), default=self.default) + else: + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) + + +FloatLike = Union[float, ScheduledFloat] + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. + """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class CutoffEstimator: + """ + Estimates cutoffs of an arbitrary numerical quantity such that a specified + proportion of items will be above the cutoff on average. + + p is the proportion of items that should be above the cutoff. 
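+
+    Example (illustrative only; the exact cutoff adapts to the values seen):
+
+      est = CutoffEstimator(p=0.2)
+      keep = [est(v) for v in values]  # on average ~20% of entries end up True
+                                       # once the running cutoff has adapted.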
+ """ + + def __init__(self, p: float): + self.p = p + # total count of items + self.count = 0 + # total count of items that were above the cutoff + self.count_above = 0 + # initial cutoff value + self.cutoff = 0 + + def __call__(self, x: float) -> bool: + """ + Returns true if x is above the cutoff. + """ + ans = x > self.cutoff + self.count += 1 + if ans: + self.count_above += 1 + cur_p = self.count_above / self.count + delta_p = cur_p - self.p + if (delta_p > 0) == ans: + q = abs(delta_p) + self.cutoff = x * q + self.cutoff * (1 - q) + return ans + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim=dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x**2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BiasNormFunction(torch.autograd.Function): + # This computes: + # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() + # return x * scales + # (after unsqueezing the bias), but it does it in a memory-efficient way so that + # it can just store the returned value (chances are, this will also be needed for + # some other reason, related to the next operation, so we can save memory). 
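+    #
+    # Shape sketch (illustrative, assuming channel_dim == -1): for x of shape
+    # (batch, time, num_channels), `scales` has shape (batch, time, 1), i.e. the
+    # normalization is computed independently for every frame.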
+ @staticmethod + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: + assert bias.ndim == 1 + if channel_dim < 0: + channel_dim = channel_dim + x.ndim + ctx.store_output_for_backprop = store_output_for_backprop + ctx.channel_dim = channel_dim + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + ans_or_x, scales, bias, log_scale = ctx.saved_tensors + if ctx.store_output_for_backprop: + x = ans_or_x / scales + else: + x = ans_or_x + x = x.detach() + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + return x.grad, bias.grad.flatten(), log_scale.grad, None, None + + +class BiasNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + Instead, we give the BiasNorm a trainable bias that it can use when + computing the scale for normalization. We also give it a (scalar) + trainable scale on the output. + + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interpreted as an offset from the input's ndim if negative. + This is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + log_scale: the initial log-scale that we multiply the output by; this + is learnable. + log_scale_min: FloatLike, minimum allowed value of log_scale + log_scale_max: FloatLike, maximum allowed value of log_scale + store_output_for_backprop: only possibly affects memory use; recommend + to set to True if you think the output of this module is more likely + than the input of this module to be required to be stored for the + backprop. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. 
+ log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, + ) -> None: + super(BiasNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.log_scale = nn.Parameter(torch.tensor(log_scale)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + + self.log_scale_min = log_scale_min + self.log_scale_max = log_scale_max + + self.store_output_for_backprop = store_output_for_backprop + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + channel_dim = self.channel_dim + if channel_dim < 0: + channel_dim += x.ndim + bias = self.bias + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() + return x * scales + + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) + + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. 
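+
+    Example (a sketch only; the sizes here are made up for illustration):
+
+      proj = ScaledLinear(256, 512, bias=True, initial_scale=0.25)
+      # `proj` is an ordinary nn.Linear(256, 512); only the initial values of
+      # its weight (and bias) have been scaled down by 0.25.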
+ """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +class LoRALayer: + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + +class ScaledLinear_lora(nn.Linear, LoRALayer): + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + fan_in_fan_out: bool = False, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + initial_scale: float = 1.0, + merge_weights: bool = True, + **kwargs, + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__( + self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights, + ) + + self.initial_scale = initial_scale + self.fan_in_fan_out = fan_in_fan_out + if r > 0: + self.lora_A = nn.Parameter(torch.full((r, in_features), 0.0)) + self.lora_B = nn.Parameter(torch.full((out_features, r), 0.0)) + self.scaling = self.lora_alpha / self.r + self.weight.requires_grad = False + + self.reset_parameters() + + def reset_parameters(self): + # initialize the parameters + nn.Linear.reset_parameters(self) + if hasattr(self, "lora_A"): + initial_scale = self.initial_scale + with torch.no_grad(): + self.weight[:] *= initial_scale + if self.bias is not None: + nn.init.uniform_( + self.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + if hasattr(self, "lora_A"): + # initialize B the same way as the default for nn.Linear and A to zero + # this is different than what is described in the paper but should not affect performance + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def train(self, mode: bool = True): + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + nn.Linear.train(self, mode) + if mode: + # We don't want the weights to be merged in training mode + if self.merge_weights and self.merged: + if self.r > 0: + self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + self.merged = False + else: + # When evaluating the model, we merge the weights for simplicity + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + delta_result = ( + self.lora_dropout(x) + @ self.lora_A.transpose(0, 1) + @ self.lora_B.transpose(0, 1) + ) + return result + delta_result * self.scaling + else: + return F.linear(x, T(self.weight), bias=self.bias) + + +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. 
+ + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: + """ + Behaves like a constructor of a modified version of nn.Conv2d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False, but: + NO PADDING-RELATED ARGS. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv2d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +class ChunkCausalDepthwiseConv1d(torch.nn.Module): + """ + Behaves like a depthwise 1d convolution, except that it is causal in + a chunkwise way, as if we had a block-triangular attention mask. + The chunk size is provided at test time (it should probably be + kept in sync with the attention mask). + + This has a little more than twice the parameters of a conventional + depthwise conv1d module: we implement it by having one + depthwise convolution, of half the width, that is causal (via + right-padding); and one depthwise convolution that is applied only + within chunks, that we multiply by a scaling factor which depends + on the position within the chunk. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): + super().__init__() + assert kernel_size % 2 == 1 + + half_kernel_size = (kernel_size + 1) // 2 + # will pad manually, on one side. + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) + + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) + + # first row is correction factors added to the scale near the left edge of the chunk, + # second row is correction factors added to the scale near the right edge of the chunk, + # both of these are added to a default scale of 1.0. 
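+        # (Clarifying note: the parameter created below has shape
+        # (2, channels, kernel_size); index 0 holds the left-edge corrections,
+        # index 1 the right-edge corrections, matching _get_chunk_scale() below.)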
+ self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) + self.kernel_size = kernel_size + + with torch.no_grad(): + self.causal_conv.weight[:] *= initial_scale + self.chunkwise_conv.weight[:] *= initial_scale + if bias: + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: + """ + Forward function. Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + """ + (batch_size, num_channels, seq_len) = x.shape + + # half_kernel_size = self.kernel_size + 1 // 2 + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + if chunk_size < 0 or chunk_size > seq_len: + chunk_size = seq_len + right_pad = -seq_len % chunk_size + + x = torch.nn.functional.pad(x, (left_pad, right_pad)) + + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + num_chunks = x_chunk.shape[2] // chunk_size + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size) + + x_chunk = x_chunk * chunk_scale + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] + + return x_chunk + x_causal + + def _get_chunk_scale(self, chunk_size: int): + """Returns tensor of shape (num_channels, chunk_size) that will be used to + scale the output of self.chunkwise_conv.""" + left_edge = self.chunkwise_conv_scale[0] + right_edge = self.chunkwise_conv_scale[1] + if chunk_size < self.kernel_size: + left_edge = left_edge[:, :chunk_size] + right_edge = right_edge[:, -chunk_size:] + else: + t = chunk_size - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) + return 1.0 + (left_edge + right_edge) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Streaming Forward function. + + Args: + x: a Tensor of shape (batch_size, channels, seq_len) + cache: cached left context of shape (batch_size, channels, left_pad) + """ + (batch_size, num_channels, seq_len) = x.shape + + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. 
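+        # Streaming usage sketch (illustrative): feed chunks in order, passing in
+        # the cache returned by the previous call; for the very first chunk, a
+        # cache of torch.zeros(batch_size, channels, self.kernel_size // 2) is
+        # consistent with the assertion below.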
+ left_pad = self.kernel_size // 2 + + # Pad cache + assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[..., -left_pad:] + + x_causal = self.causal_conv(x) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size=seq_len) + x_chunk = x_chunk * chunk_scale + + return x_chunk + x_causal, cache + + +class BalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + ctx.save_for_backward(x) + ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors + (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.detach() + x.requires_grad = True + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. + rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = m_loss + r_loss + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + x_grad_float = x_grad.to(torch.float32) + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) + x_grad = x_grad_mod.to(x_grad.dtype) + except Exception as e: + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) + + return x_grad, None, None, None, None, None, None + + +class Balancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. 
+ min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, + ): + super().__init__() + + if prob is None: + prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) + self.prob = prob + # 5% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.05) + + # actually self.num_channels is no longer needed except for an assertion. + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.min_abs = min_abs + self.max_abs = max_abs + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): + return _no_op(x) + + prob = float(self.prob) + if random.random() < prob: + # The following inner-functions convert from the way we historically specified + # these limitations, as limits on the absolute value and the proportion of positive + # values, to limits on the RMS value and the (mean / stddev). + def _abs_to_rms(x): + # for normally distributed data, if the expected absolute value is x, the + # expected rms value will be sqrt(pi/2) * x. + return 1.25331413732 * x + + def _proportion_positive_to_mean(x): + def _atanh(x): + eps = 1.0e-10 + # eps is to prevent crashes if x is exactly 0 or 1. + # we'll just end up returning a fairly large value. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 + + def _approx_inverse_erf(x): + # 1 / (sqrt(pi) * ln(2)), + # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions + # this approximation is extremely crude and gets progressively worse for + # x very close to -1 or +1, but we mostly care about the "middle" region + # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, + # and math.erf(0.0407316414078772) = 0.045935330944660666, + # which is pretty close to 0.05. 
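+                    # (Additional sanity check: the approximation is exact at 0,
+                    # so a proportion of exactly 0.5 maps to a target mean of 0.)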
+ return 0.8139535143 * _atanh(x) + + # first convert x from the range 0..1 to the range -1..1 which the error + # function returns + x = -1 + (2 * x) + return _approx_inverse_erf(x) + + min_mean = _proportion_positive_to_mean(float(self.min_positive)) + max_mean = _proportion_positive_to_mean(float(self.max_positive)) + min_rms = _abs_to_rms(float(self.min_abs)) + max_rms = _abs_to_rms(float(self.max_abs)) + grad_scale = float(self.grad_scale) + + assert x.shape[self.channel_dim] == self.num_channels + + return BalancerFunction.apply( + x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. 
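+    # (Illustrative expectation: for data that is already white, e.g.
+    # x = torch.randn(10000, 128) with num_groups=1, this metric comes out close
+    # to 1.0; correlated or unevenly scaled channels push it above 1.0.)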
+ x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: + ctx.save_for_backward(x) + ctx.module = module + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + w = ctx.module + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, w.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" + ) + + if metric < float(w.whitening_limit): + w.prob = w.min_prob + return x_grad, None + else: + w.prob = w.max_prob + metric.backward() + penalty_grad = x_detached.grad + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None + except Exception as e: + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." + ) + return x_grad, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert float(whitening_limit) >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + self.grad_scale = grad_scale + + if isinstance(prob, float): + prob = (prob, prob) + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob <= self.max_prob <= 1 + self.prob = self.max_prob + self.name = None # will be set in training loop + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. 
+ In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + grad_scale = float(self.grad_scale) + if not x.requires_grad or random.random() > self.prob or grad_scale == 0: + return _no_op(x) + else: + return WhiteningPenaltyFunction.apply(x, self) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) + + +def with_loss(x, y, name): + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y, name) + + +class ScaleGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, alpha: float) -> Tensor: + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad: Tensor): + return grad * ctx.alpha, None + + +def scale_grad(x: Tensor, alpha: float): + return ScaleGradFunction.apply(x, alpha) + + +class ScaleGrad(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return x + return scale_grad(x, self.alpha) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x,) = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. 
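+    # Usage sketch (illustrative), mirroring BiasNorm.forward() above:
+    #   log_scale = limit_param_value(self.log_scale, min=-1.5, max=1.5,
+    #                                 training=self.training)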
+ if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.044 + ceil = 1.2 + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. +class Dropout2(nn.Module): + def __init__(self, p: FloatLike): + super().__init__() + self.p = p + + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) + + +class MulForDropout3(torch.autograd.Function): + # returns (x * y * alpha) where alpha is a float and y doesn't require + # grad and is zero-or-one. 
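+    # (Memory note: only the product `ans` is saved for backward; the dropout
+    # mask is recovered as (ans != 0), which is why y must be zero-or-one.)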
+ @staticmethod + @custom_fwd + def forward(ctx, x, y, alpha): + assert not y.requires_grad + ans = x * y * alpha + ctx.save_for_backward(ans) + ctx.alpha = alpha + return ans + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + (ans,) = ctx.saved_tensors + x_grad = ctx.alpha * ans_grad * (ans != 0) + return x_grad, None, None + + +# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, +# and it lets you choose one dimension to share the dropout mask over +class Dropout3(nn.Module): + def __init__(self, p: FloatLike, shared_dim: int): + super().__init__() + self.p = p + self.shared_dim = shared_dim + + def forward(self, x: Tensor) -> Tensor: + p = float(self.p) + if not self.training or p == 0: + return _no_op(x) + scale = 1.0 / (1 - p) + rand_shape = list(x.shape) + rand_shape[self.shared_dim] = 1 + mask = torch.rand(*rand_shape, device=x.device) > p + ans = MulForDropout3.apply(x, mask, scale) + return ans + + +class SwooshLFunction(torch.autograd.Function): + """ + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + coeff = -0.08 + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 + + if not requires_grad: + return y + + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = coeff + ceil = 1.0 + coeff + 0.005 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + + coeff = -0.08 + floor = coeff + ceil = 1.0 + coeff + 0.005 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshL(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation.""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + if not x.requires_grad: + return k2.swoosh_l_forward(x) + else: + return k2.swoosh_l(x) + # return SwooshLFunction.apply(x) + + +class SwooshLOnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation.""" + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 + + +class SwooshRFunction(torch.autograd.Function): + """ + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + + derivatives are between -0.08 and 0.92. 
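+
+    For reference: swoosh_r(0) is 0 (that is what the constant 0.313261687
+    arranges), and for large positive x it grows roughly like 0.92 * x.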
+ """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + + if not requires_grad: + return y + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = -0.08 + ceil = 0.925 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.08 + ceil = 0.925 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshR(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation.""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + if not x.requires_grad: + return k2.swoosh_r_forward(x) + else: + return k2.swoosh_r(x) + # return SwooshRFunction.apply(x) + + +class SwooshROnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation.""" + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 + + +# simple version of SwooshL that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshLForward(x: Tensor): + x_offset = x - 4.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.035 + + +# simple version of SwooshR that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshRForward(x: Tensor): + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.313261687 + + +class ActivationDropoutAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): + if dropout_p != 0.0: + dropout_shape = list(x.shape) + if dropout_shared_dim is not None: + dropout_shape[dropout_shared_dim] = 1 + # else it won't be very memory efficient. + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) + else: + dropout_mask = None + + ctx.save_for_backward(x, weight, bias, dropout_mask) + + ctx.activation = activation + + forward_activation_dict = { + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, + } + # it will raise a KeyError if this fails. This will be an error. We let it + # propagate to the user. 
+ activation_func = forward_activation_dict[activation] + x = activation_func(x) + if dropout_mask is not None: + x = x * dropout_mask + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias, dropout_mask) = saved + + forward_and_deriv_activation_dict = { + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, + } + # the following lines a KeyError if the activation is unrecognized. + # This will be an error. We let it propagate to the user. + func = forward_and_deriv_activation_dict[ctx.activation] + + y, func_deriv = func(x) + if dropout_mask is not None: + y = y * dropout_mask + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + if dropout_mask is not None: + # order versus func_deriv does not matter + x_deriv = x_deriv * dropout_mask + + return x_deriv, weight_deriv, bias_deriv, None, None, None + + +class ActivationDropoutAndLinear(torch.nn.Module): + """ + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) + + self.weight = l.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. 
+ self.register_parameter("bias", l.bias) + + self.activation = activation + self.dropout_p = dropout_p + self.dropout_shared_dim = dropout_shared_dim + + def forward(self, x: Tensor): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + if self.activation == "SwooshL": + x = SwooshLForward(x) + elif self.activation == "SwooshR": + x = SwooshRForward(x) + else: + assert False, self.activation + return torch.nn.functional.linear(x, self.weight, self.bias) + + return ActivationDropoutAndLinearFunction.apply( + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) + + +class ActivationDropoutAndLinear_lora(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + initial_scale: float = 1.0, + ): + super().__init__() + self.l = ScaledLinear_lora( + in_features=in_channels, + out_features=out_channels, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + initial_scale=initial_scale, + bias=bias, + ) + self.weight = self.l.weight + self.register_parameter("bias", self.l.bias) + + if activation == "SwooshL": + self.activation = SwooshL() + elif activation == "SwooshR": + self.activation = SwooshR() + else: + assert False, activation + self.dropout = Dropout3(dropout_p, dropout_shared_dim) + + def forward(self, x: Tensor): + return self.l(self.dropout(self.activation(x))) + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + if num_channels <= x.shape[-1]: + return x[..., :num_channels] + else: + shape = list(x.shape) + shape[-1] = num_channels - shape[-1] + zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=-1) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = Balancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + min_abs=0.0, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_sign: x = ", x) + print("_test_balancer_sign: y grad = ", y_grad) + print("_test_balancer_sign: x grad = ", x.grad) + + +def _test_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = Balancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + min_abs=0.2, + max_abs=0.7, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_magnitude: x = ", x) + 
print("_test_balancer_magnitude: y grad = ", y_grad) + print("_test_balancer_magnitude: x grad = ", x.grad) + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshL() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshR() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +def _test_piecewise_linear(): + p = PiecewiseLinear((0, 10.0)) + for x in [-100, 0, 100]: + assert p(x) == 10.0 + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: + print("x, y = ", x, y) + assert p(x) == y, (x, p(x), y) + + q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] + pq = p.max(q) + for x in x_vals: + y1 = max(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p.min(q) + for x in x_vals: + y1 = min(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p + q + for x in x_vals: + y1 = p(x) + q(x) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + + +def _test_activation_dropout_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + # actually we don't test for dropout_p != 0.0 because forward functions will give + # different answers. This is because we are using the k2 implementation of + # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() + # internally, messing up the random state. + for dropout_p in [0.0]: + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) + with torch.no_grad(): + m2.weight[:] = m1[2].weight + if bias: + m2.bias[:] = m1[2].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + # TEMP. 
+ assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) + if bias: + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + + # the SwooshL() implementation has a noisy gradient due to 1-byte + # storage of it. + assert isclose(x1.grad, x2.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_piecewise_linear() + _test_softmax() + _test_whiten() + _test_balancer_sign() + _test_balancer_magnitude() + _test_double_swish_deriv() + _test_swooshr_deriv() + _test_swooshl_deriv() + _test_activation_dropout_and_linear() diff --git a/egs/librispeech/ASR/zipformer_lora/scaling_converter.py b/egs/librispeech/ASR/zipformer_lora/scaling_converter.py new file mode 120000 index 000000000..bc7c7b5e3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/scaling_converter.py @@ -0,0 +1 @@ +../zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/subsampling.py b/egs/librispeech/ASR/zipformer_lora/subsampling.py new file mode 120000 index 000000000..d178adc2e --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/subsampling.py @@ -0,0 +1 @@ +../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py new file mode 100755 index 000000000..3ccf7d2f1 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -0,0 +1,1398 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
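A side note on the `isclose` helper inside _test_activation_dropout_and_linear above: it compares gradients by cosine similarity rather than element-wise closeness, because the Swoosh gradients are stored with roughly 1-byte precision and are therefore noisy. A standalone sketch of that check (the helper name is made up for illustration):

import torch

def grads_agree(a: torch.Tensor, b: torch.Tensor, threshold: float = 0.9) -> bool:
    # True if the cosine similarity between the two tensors exceeds `threshold`.
    return bool((a * b).sum() > threshold * ((a ** 2).sum() * (b ** 2).sum()).sqrt())

g1 = torch.randn(100)
g2 = g1 + 0.05 * torch.randn(100)   # same direction plus a little noise
print(grads_agree(g1, g2))          # almost certainly True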
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
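    # Worked example (using the illustrative values from the usage notes above:
    # --max-duration 1000, --world-size 4, and the default --ref-duration 600):
    # each optimizer step then covers 1000 * 4 = 4000 seconds of audio, so after
    # batch_idx_train = 300 steps this returns 300 * 4000 / 600 = 2000, and the
    # batch-count-based schedules behave as if 2000 reference-duration batches
    # had been processed.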
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + 
vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. 
It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
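            # scaler._scale holds the current AMP loss scale.  GradScaler halves the
            # scale after a step whose gradients contain inf/nan and, with default
            # settings, only doubles it again after growth_interval (2000) clean
            # steps.  The code below forces a faster recovery: every 100 batches it
            # doubles scales below 8 (and, every 400 batches, scales below 32),
            # snapshots the model and warns once the scale drops below 0.01, and
            # raises if it collapses below 1e-5, which usually indicates divergence.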
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py new file mode 100644 index 000000000..43865609a --- /dev/null +++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py @@ -0,0 +1,2522 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. 
+) +from scaling import ( + ActivationDropoutAndLinear, + ActivationDropoutAndLinear_lora, + Balancer, + BiasNorm, + ChunkCausalDepthwiseConv1d, + Dropout2, + FloatLike, + ScaledLinear_lora, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. 
+ """ + + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + use_lora: bool = True, + lora_r: int = 0, + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + self.lora_r = lora_r if use_lora else 0 + + for u, d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + lora_r=self.lora_r, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. 
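            # For example, with the default warmup_batches = 4000 and six stacks
            # (num_encoders = 6), stack i is warmed up over batches
            # 4000 * (i + 1) / 7 .. 4000 * (i + 2) / 7, i.e. roughly 571..1143 for
            # the first stack and about 3429..4000 for the last one.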
+ encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. 
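get_feature_masks above zeroes, on a random subset of sequences (feature_mask_dropout_prob = 0.125), every encoder channel above encoder_unmasked_dim, and zeroes the upper half of that masked region about twice as often. A tiny standalone illustration of the masking pattern, with made-up sizes that are not tied to this recipe:

import torch

torch.manual_seed(0)
batch_size, channels, u1 = 4, 8, 4          # u1 plays the role of encoder_unmasked_dim
u2 = u1 + (channels - u1) // 2              # 6
p = 0.125                                   # feature_mask_dropout_prob

mask1 = (torch.rand(1, batch_size, 1) > p).float()
mask2 = torch.logical_and(mask1.bool(), torch.rand(1, batch_size, 1) > p).float()

feature_mask = torch.ones(1, batch_size, channels)
feature_mask[:, :, u1:u2] *= mask1          # channels u1..u2-1 follow mask1
feature_mask[:, :, u2:] *= mask2            # channels >= u2 are dropped more often
print(feature_mask[0])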
+ """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + ) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). 
+ chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. 
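_get_attn_mask above produces a (tgt_seq_len, src_seq_len) boolean mask in which True means "may not attend": a frame in chunk c can see chunks c - left_context_chunks through c and nothing later. The pattern can be reproduced in a few lines (toy sizes, purely for illustration):

import torch

seq_len, chunk_size, left_context_chunks = 8, 2, 1
t = torch.arange(seq_len, dtype=torch.int32)
c = t // chunk_size                         # chunk index of each frame
src_c = c
tgt_c = c.unsqueeze(-1)
# True = masked: either a future chunk, or further left than the allowed context.
attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
print(attn_mask.int())
# Row 4 (a frame in chunk 2) is 1 1 0 0 0 0 1 1: it attends to chunks 1 and 2 only.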
+ assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. 
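+        lora_r: bottleneck dimension of the LoRA adapters used in the attention and
+            feedforward projections of this layer (see RelPositionMultiheadAttentionWeights
+            below); lora_alpha and lora_dropout are the usual LoRA scaling and
+            dropout hyper-parameters.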
+ + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + lora_r: int = 0, + lora_alpha: int = 4, + lora_dropout: float = 0.0, + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. 
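+        # (Each of these skip rates is typically a ScheduledFloat, i.e. a value that
+        # is piecewise-linearly interpolated against the training batch count: for
+        # example ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0)) starts at
+        # 0.2, decays to 0.05 by batch 4000 and reaches 0.0 at batch 16000, so the
+        # extra regularization is strongest early in training.)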
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.self_attn1 = SelfAttention( + embed_dim, + num_heads, + value_head_dim, + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.self_attn2 = SelfAttention( + embed_dim, + num_heads, + value_head_dim, + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.feed_forward1 = FeedforwardModule( + embed_dim, + (feedforward_dim * 3) // 4, + dropout, + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.feed_forward2 = FeedforwardModule( + embed_dim, + feedforward_dim, + dropout, + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.feed_forward3 = FeedforwardModule( + embed_dim, + (feedforward_dim * 5) // 4, + dropout, + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. 
+ x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif not self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) + + # bypass in the middle of the layer. 
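+        # (BypassModule computes src_orig + (src - src_orig) * bypass_scale, i.e. a
+        # learned per-channel interpolation between the layer input and the output
+        # accumulated so far; see the BypassModule class below.)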
+ src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. 
+ + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.bypass(src_orig, src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. 
+ + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + return output + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + output, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + output, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + return output, new_states + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. 
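+    Example (illustrative; shapes chosen arbitrarily)::
+
+        >>> m = BypassModule(embed_dim=4)
+        >>> src_orig = torch.zeros(10, 2, 4)  # original input to the module
+        >>> src = torch.ones(10, 2, 4)        # output of the bypassed computation
+        >>> out = m(src_orig, src)            # src_orig + (src - src_orig) * bypass_scale
+        >>> out.shape
+        torch.Size([10, 2, 4])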
+ """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, downsample, dropout) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. 
+ """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Downsample, go through encoder, upsample, in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); + True means masked position. May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + src_orig = src + src = self.downsample(src) + + src, new_states = self.encoder.streaming_forward( + src, + states=states, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src), new_states + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. 
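+    Example (illustrative; shapes chosen arbitrarily)::
+
+        >>> up = SimpleUpsample(num_channels=4, upsample=2)
+        >>> x = torch.randn(10, 3, 4)     # (seq_len, batch_size, num_channels)
+        >>> up(x).shape                   # every frame is repeated `upsample` times
+        torch.Size([20, 3, 4])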
+ """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. 
+ # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + lora_r: the bottleneck dimension of LoRA + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + lora_r: int = 0, + lora_alpha: int = 4, + lora_dropout: float = 0.0, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. 
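+        # The single in_proj defined below emits, for every frame, the concatenation
+        # [query | key | positional query], which forward() splits apart again; e.g.
+        # with num_heads=8, query_head_dim=24 and pos_head_dim=4 it outputs
+        # (24 + 24 + 4) * 8 = 416 channels per frame.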
+ + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + # self.in_proj = ScaledLinear( + # embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + # ) + self.in_proj = ScaledLinear_lora( + in_features=embed_dim, + out_features=in_proj_dim, + r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + initial_scale=query_head_dim**-0.25, + bias=True, + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. 
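+        # Shape bookkeeping for the score computation below: after the reshape and
+        # permute, q is (head, batch, time1, query_head_dim) and k is
+        # (head, batch, query_head_dim, time2), so matmul(q, k) gives raw attention
+        # scores of shape (head, batch, time1, time2); pos_scores of the same shape
+        # are added after the relative->absolute as_strided conversion.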
+ + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. 
It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + left_context_len + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] 
+ pos_scores = torch.matmul(p, pos_emb) + + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + lora_r: int = 0, + lora_alpha: int = 4, + lora_dropout: float = 0.0, + ) -> None: + super().__init__() + self.in_proj = ScaledLinear_lora( + in_features=embed_dim, + out_features=num_heads * value_head_dim, + r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + bias=True, + ) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. 
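+        Example (illustrative; shapes chosen arbitrarily)::
+
+            >>> m = SelfAttention(embed_dim=256, num_heads=4, value_head_dim=12)
+            >>> x = torch.randn(10, 2, 256)
+            >>> attn = torch.full((4, 2, 10, 10), 0.1)  # uniform weights, rows sum to 1
+            >>> m(x, attn).shape
+            torch.Size([10, 2, 256])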
+ """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
+ x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__( + self, + embed_dim: int, + feedforward_dim: int, + dropout: FloatLike, + lora_r: int = 0, + lora_alpha: int = 4, + lora_dropout: float = 0.0, + ): + super(FeedforwardModule, self).__init__() + self.in_proj = ScaledLinear_lora( + in_features=embed_dim, + out_features=feedforward_dim, + r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + bias=True, + ) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear_lora( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. 
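+        # (Of the three equal chunks produced by in_proj: `s` becomes a tanh gate on
+        # `x`, the gated `x` is then mixed across time using the attention weights
+        # passed in from the attention module, and `y` gates the attended result
+        # just before out_proj.)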
+ + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. 
This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). 
+ cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + + c = Zipformer2( + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) From bf2f94346c04496e3b4c9c61dfa97a474124fcae Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 18 Mar 2024 11:57:47 +0800 Subject: [PATCH 130/216] Enabling `char_level` and `compute_CER` for `aishell` recipe (#1554) * init fix Co-authored-by: Fangjun Kuang --- egs/aishell/ASR/conformer_ctc/decode.py | 8 ++++++-- egs/aishell/ASR/conformer_mmi/decode.py | 8 ++++++-- egs/aishell/ASR/pruned_transducer_stateless2/decode.py | 8 ++++++-- egs/aishell/ASR/pruned_transducer_stateless3/decode.py | 8 ++++++-- egs/aishell/ASR/pruned_transducer_stateless7/decode.py | 8 ++++++-- .../ASR/pruned_transducer_stateless7_bbpe/decode.py | 8 ++++++-- .../pruned_transducer_stateless7_streaming/decode.py | 8 ++++++-- egs/aishell/ASR/tdnn_lstm_ctc/decode.py | 10 ++++++++-- egs/aishell/ASR/transducer_stateless/decode.py | 8 ++++++-- .../ASR/transducer_stateless_modified-2/decode.py | 8 ++++++-- .../ASR/transducer_stateless_modified/decode.py | 8 ++++++-- egs/aishell/ASR/whisper/decode.py | 8 ++++++-- egs/aishell/ASR/zipformer/decode.py | 8 ++++++-- 13 files changed, 80 insertions(+), 26 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 74a7b5933..2cb476e20 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -419,7 +419,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) if 
enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -432,7 +432,11 @@ def save_results( results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + f, + f"{test_set_name}-{key}", + results_char, + enable_log=enable_log, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index 20a855e7f..8a2daa93e 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -431,7 +431,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -444,7 +444,11 @@ def save_results( results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + f, + f"{test_set_name}-{key}", + results_char, + enable_log=enable_log, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index fb6c7c481..f41ea6776 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -390,7 +390,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -402,7 +402,11 @@ def save_results( results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True + f, + f"{test_set_name}-{key}", + results_char, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 27c64efaa..3901a330c 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -526,7 +526,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -538,7 +538,11 @@ def save_results( results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True + f, + f"{test_set_name}-{key}", + results_char, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer diff 
--git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py index 696eea906..d50bccf82 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py @@ -444,7 +444,7 @@ def save_results( for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - store_transcripts(filename=recog_path, texts=results_char) + store_transcripts(filename=recog_path, texts=results_char, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -452,7 +452,11 @@ def save_results( errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True + f, + f"{test_set_name}-{key}", + results_char, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py index da9000164..46f542641 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py @@ -581,7 +581,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -594,7 +594,11 @@ def save_results( with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True + f, + f"{test_set_name}-{key}", + results_char, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py index 0e783e92b..61b929091 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -492,7 +492,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -500,7 +500,11 @@ def save_results( errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index 824ca2a92..05e52f560 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -278,7 +278,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" results = sorted(results) - 
store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -289,7 +289,13 @@ def save_results( for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) + wer = write_error_stats( + f, + f"{test_set_name}-{key}", + results_char, + enable_log=True, + compute_CER=True, + ) test_set_wers[key] = wer logging.info("Wrote detailed error stats to {}".format(errs_filename)) diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index d23f4f883..d958a6338 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -327,7 +327,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. @@ -338,7 +338,11 @@ def save_results( results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True + f, + f"{test_set_name}-{key}", + results_char, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index d164b6890..57f7a8239 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -372,7 +372,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -384,7 +384,11 @@ def save_results( results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True + f, + f"{test_set_name}-{key}", + results_char, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 0a7d87fe8..56f3724eb 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -376,7 +376,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -388,7 +388,11 @@ def 
save_results( results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True + f, + f"{test_set_name}-{key}", + results_char, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py index 7f841dcb7..c632d0757 100755 --- a/egs/aishell/ASR/whisper/decode.py +++ b/egs/aishell/ASR/whisper/decode.py @@ -358,7 +358,7 @@ def save_results( params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -373,7 +373,11 @@ def save_results( results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + f, + f"{test_set_name}-{key}", + results_char, + enable_log=enable_log, + compute_CER=True, ) test_set_wers[key] = wer diff --git a/egs/aishell/ASR/zipformer/decode.py b/egs/aishell/ASR/zipformer/decode.py index 1968904ae..538189e52 100755 --- a/egs/aishell/ASR/zipformer/decode.py +++ b/egs/aishell/ASR/zipformer/decode.py @@ -560,7 +560,7 @@ def save_results( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -570,7 +570,11 @@ def save_results( ) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer From 9b0eae3b4aef7a26fdbf35ff7a51956c8aa10ac0 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 18 Mar 2024 17:14:29 +0800 Subject: [PATCH 131/216] fixes for init value of `diagnostics.TensorDiagnosticOptions` (#1555) --- egs/aishell/ASR/whisper/train.py | 2 +- .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/finetune.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/train.py | 2 +- egs/multi_zh-hans/ASR/whisper/train.py | 2 +- egs/wenetspeech/ASR/whisper/train.py | 2 +- egs/wenetspeech/KWS/zipformer/train.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index 073b23713..6ccb8d363 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -793,7 +793,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 8e16cc0bf..a3f387636 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1035,7 +1035,7 
@@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py index 47c4ed312..81c69e5e0 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -1119,7 +1119,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py index dc5d0a858..728104580 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1035,7 +1035,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py index 11a22eec1..b1b60077c 100644 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -824,7 +824,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py index 4b7c1ca42..6ff500ab9 100644 --- a/egs/wenetspeech/ASR/whisper/train.py +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -803,7 +803,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index b5cf3359a..eddec7303 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -1187,7 +1187,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) From eec12f053d47a7988c83f63e87059b3d67a450ad Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 18 Mar 2024 17:53:52 +0800 Subject: [PATCH 132/216] Use piper_phonemize as text tokenizer in vctk TTS recipe (#1522) * to align with PR #1524 --- egs/vctk/TTS/README.md | 3 +- egs/vctk/TTS/local/prepare_token_file.py | 105 +--------------------- egs/vctk/TTS/local/prepare_tokens_vctk.py | 11 ++- egs/vctk/TTS/prepare.sh | 20 +++-- egs/vctk/TTS/vits/export-onnx.py | 19 ++-- egs/vctk/TTS/vits/infer.py | 9 +- egs/vctk/TTS/vits/test_onnx.py | 7 +- egs/vctk/TTS/vits/train.py | 12 +-- egs/vctk/TTS/vits/tts_datamodule.py | 5 +- 9 files changed, 56 insertions(+), 135 deletions(-) mode change 100755 => 120000 egs/vctk/TTS/local/prepare_token_file.py diff --git a/egs/vctk/TTS/README.md b/egs/vctk/TTS/README.md index c07516b77..c2703dbe2 100644 
--- a/egs/vctk/TTS/README.md +++ b/egs/vctk/TTS/README.md @@ -10,7 +10,7 @@ The above information is from the [CSTR VCTK website](https://datashare.ed.ac.uk This recipe provides a VITS model trained on the VCTK dataset. -Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05), note that this model was pretrained on the Edinburgh DataShare VCTK dataset. +Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2024-03-18), note that this model was pretrained on the Edinburgh DataShare VCTK dataset. For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/vctk/vits.html). @@ -21,7 +21,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --world-size 4 \ --num-epochs 1000 \ --start-epoch 1 \ - --use-fp16 1 \ --exp-dir vits/exp \ --tokens data/tokens.txt --max-duration 350 diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py deleted file mode 100755 index c6636c3ad..000000000 --- a/egs/vctk/TTS/local/prepare_token_file.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file reads the texts in given manifest and generates the file that maps tokens to IDs. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict - -from lhotse import load_manifest - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--manifest-file", - type=Path, - default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"), - help="Path to the manifest file", - ) - - parser.add_argument( - "--tokens", - type=Path, - default=Path("data/tokens.txt"), - help="Path to the tokens", - ) - - return parser.parse_args() - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_token2id(manifest_file: Path) -> Dict[str, int]: - """Return a dict that maps token to IDs.""" - extra_tokens = [ - "", # 0 for blank - "", # 1 for sos and eos symbols. 
- "", # 2 for OOV - ] - all_tokens = set() - - cut_set = load_manifest(manifest_file) - - for cut in cut_set: - # Each cut only contain one supervision - assert len(cut.supervisions) == 1, len(cut.supervisions) - for t in cut.tokens: - all_tokens.add(t) - - all_tokens = extra_tokens + list(all_tokens) - - token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} - return token2id - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - manifest_file = Path(args.manifest_file) - out_file = Path(args.tokens) - - token2id = get_token2id(manifest_file) - write_mapping(out_file, token2id) diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py new file mode 120000 index 000000000..afc29a22b --- /dev/null +++ b/egs/vctk/TTS/local/prepare_token_file.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/prepare_token_file.py \ No newline at end of file diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py index 32e1c7dfa..0748eba5a 100755 --- a/egs/vctk/TTS/local/prepare_tokens_vctk.py +++ b/egs/vctk/TTS/local/prepare_tokens_vctk.py @@ -24,9 +24,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t import logging from pathlib import Path -import g2p_en import tacotron_cleaner.cleaners from lhotse import CutSet, load_manifest +from piper_phonemize import phonemize_espeak from tqdm.auto import tqdm @@ -37,17 +37,20 @@ def prepare_tokens_vctk(): partition = "all" cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - g2p = g2p_en.G2p() new_cuts = [] for cut in tqdm(cut_set): # Each cut only contains one supervision - assert len(cut.supervisions) == 1, len(cut.supervisions) + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) text = cut.supervisions[0].text # Text normalization text = tacotron_cleaner.cleaners.custom_english_cleaners(text) # Convert to phonemes - cut.tokens = g2p(text) + tokens_list = phonemize_espeak(text, "en-us") + tokens = [] + for t in tokens_list: + tokens.extend(t) + cut.tokens = tokens new_cuts.append(cut) new_cut_set = CutSet.from_cuts(new_cuts) diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh index 152c7b168..aab075312 100755 --- a/egs/vctk/TTS/prepare.sh +++ b/egs/vctk/TTS/prepare.sh @@ -78,6 +78,13 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare phoneme tokens for VCTK" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: + # refer to https://github.com/rhasspy/piper-phonemize, + # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - espnet_tts_frontend: + # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.vctk_with_token.done ]; then ./local/prepare_tokens_vctk.py mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \ @@ -111,14 +118,15 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Generate token file" - # We assume you have installed g2p_en and espnet_tts_frontend. + # We assume you have installed piper_phonemize and espnet_tts_frontend. 
# If not, please install them with: - # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p - # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + # - piper_phonemize: + # refer to https://github.com/rhasspy/piper-phonemize, + # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - espnet_tts_frontend: + # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/tokens.txt ]; then - ./local/prepare_token_file.py \ - --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \ - --tokens data/tokens.txt + ./local/prepare_token_file.py --tokens data/tokens.txt fi fi diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index 80d155626..d00450f08 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -97,7 +98,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): for key, value in meta_data.items(): meta = model.metadata_props.add() meta.key = key - meta.value = value + meta.value = str(value) onnx.save(model, filename) @@ -160,6 +161,7 @@ def export_model_onnx( model: nn.Module, model_filename: str, vocab_size: int, + n_speakers: int, opset_version: int = 11, ) -> None: """Export the given generator model to ONNX format. @@ -212,10 +214,15 @@ def export_model_onnx( ) meta_data = { - "model_type": "VITS", + "model_type": "vits", "version": "1", "model_author": "k2-fsa", - "comment": "VITS generator", + "comment": "icefall", # must be icefall for models from icefall + "language": "English", + "voice": "en-us", # Choose your language appropriately + "has_espeak": 1, + "n_speakers": n_speakers, + "sample_rate": 22050, # Must match the real sample rate } logging.info(f"meta_data: {meta_data}") @@ -231,8 +238,7 @@ def main(): params.update(vars(args)) tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size with open(args.speakers) as f: @@ -265,6 +271,7 @@ def main(): model, model_filename, params.vocab_size, + params.num_spks, opset_version=opset_version, ) logging.info(f"Exported generator to {model_filename}") diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py index 06c25f02e..2e1abdefb 100755 --- a/egs/vctk/TTS/vits/infer.py +++ b/egs/vctk/TTS/vits/infer.py @@ -135,14 +135,16 @@ def infer_dataset( batch_size = len(batch["tokens"]) tokens = batch["tokens"] - tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) # tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) speakers = ( torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]) .int() @@ -214,8 +216,7 @@ def main(): device = torch.device("cuda", 0) tokenizer = 
Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size # we need cut ids to display recognition results. diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py index d85c0a27b..ae6587338 100755 --- a/egs/vctk/TTS/vits/test_onnx.py +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -122,7 +123,9 @@ def main(): model = OnnxModel(args.model_filename) text = "I went there to see the land, the people and how their system works, end quote." - tokens = tokenizer.texts_to_token_ids([text]) + tokens = tokenizer.texts_to_token_ids( + [text], intersperse_blank=True, add_sos=True, add_eos=True + ) tokens = torch.tensor(tokens) # (1, T) tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) speaker = torch.tensor([1], dtype=torch.int64) # (1, ) diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index 56f167a17..55bd69327 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -342,14 +343,16 @@ def prepare_input( torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) ) - tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers @@ -812,8 +815,7 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size vctk = VctkTtsDataModule(args) diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py index 52fc5179f..6c785d8c3 100644 --- a/egs/vctk/TTS/vits/tts_datamodule.py +++ b/egs/vctk/TTS/vits/tts_datamodule.py @@ -1,6 +1,7 @@ # Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao) +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # From 4917ac8bab2e5dc0021f17249f58b7a827a83af9 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Mon, 18 Mar 2024 11:43:29 +0100 Subject: [PATCH 133/216] allow export of onnx-streaming-models with other than 80dim input features (#1556) --- egs/librispeech/ASR/zipformer/export-onnx-streaming.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py 
b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 6bc9b1858..6320e51ca 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -333,6 +333,7 @@ def export_encoder_model_onnx( encoder_model: OnnxEncoder, encoder_filename: str, opset_version: int = 11, + feature_dim: int = 80, ) -> None: encoder_model.encoder.__class__.forward = ( encoder_model.encoder.__class__.streaming_forward @@ -343,7 +344,7 @@ def export_encoder_model_onnx( # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling T = decode_chunk_len + encoder_model.pad_length - x = torch.rand(1, T, 80, dtype=torch.float32) + x = torch.rand(1, T, feature_dim, dtype=torch.float32) init_state = encoder_model.get_init_states() num_encoders = len(encoder_model.encoder.encoder_dim) logging.info(f"num_encoders: {num_encoders}") @@ -724,6 +725,7 @@ def main(): encoder, encoder_filename, opset_version=opset_version, + feature_dim=params.feature_dim, ) logging.info(f"Exported encoder to {encoder_filename}") From 489263e5bb54a0efb55f7c1c979e6f28f9581b34 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Mar 2024 20:11:47 +0800 Subject: [PATCH 134/216] Add streaming HLG decoding for zipformer CTC. (#1557) Note it supports only CPU. --- .github/scripts/docker/Dockerfile | 6 +- .../scripts/docker/generate_build_matrix.py | 2 +- .github/scripts/librispeech/ASR/run.sh | 41 ++ .../zipformer/export-onnx-streaming-ctc.py | 4 +- .../ASR/zipformer/export-onnx-streaming.py | 8 +- .../onnx_pretrained_ctc_HLG_streaming.py | 439 ++++++++++++++++++ 6 files changed, 492 insertions(+), 8 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index 4adb7ab5c..f64446e7e 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -36,7 +36,9 @@ RUN pip install --no-cache-dir \ \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \ + cython \ dill \ + espnet_tts_frontend \ graphviz \ kaldi-decoder \ kaldi_native_io \ @@ -45,13 +47,15 @@ RUN pip install --no-cache-dir \ kaldilm \ matplotlib \ multi_quantization \ + numba \ numpy \ onnx \ onnxmltools \ onnxruntime \ + piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \ + pypinyin==0.50.0 \ pytest \ sentencepiece>=0.1.96 \ - pypinyin==0.50.0 \ six \ tensorboard \ typeguard diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 7bb8ac676..675e37c37 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20240223" kaldifeat_version = "1.25.4.dev20240223" - version = "20240223" + version = "20240318" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] torch_version += ["2.2.0", "2.2.1"] diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh index 293ed66e5..b4450afea 100755 --- a/.github/scripts/librispeech/ASR/run.sh +++ b/.github/scripts/librispeech/ASR/run.sh @@ -64,6 +64,46 @@ function run_diagnostics() { --print-diagnostics 1 } +function test_streaming_zipformer_ctc_hlg() { + 
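+  # Downloads a pre-trained small streaming zipformer checkpoint, exports
+  # its streaming CTC branch to ONNX, and then decodes the bundled test
+  # wavs with the HLG-based streaming decoder
+  # (zipformer/onnx_pretrained_ctc_HLG_streaming.py).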
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + rm $repo/exp-ctc-rnnt-small/*.onnx + ls -lh $repo/exp-ctc-rnnt-small + + # export models to onnx + ./zipformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 3 \ + --exp-dir $repo/exp-ctc-rnnt-small \ + --causal 1 \ + --use-ctc 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 + + ls -lh $repo/exp-ctc-rnnt-small + + for wav in 0.wav 1.wav 8k.wav; do + python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \ + --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ + --words $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.fst \ + $repo/test_wavs/$wav + done + + rm -rf $repo +} + function test_pruned_transducer_stateless_2022_03_12() { repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 @@ -1577,6 +1617,7 @@ function test_transducer_bpe_500_2021_12_23() { prepare_data run_diagnostics +test_streaming_zipformer_ctc_hlg test_pruned_transducer_stateless_2022_03_12 test_pruned_transducer_stateless2_2022_04_29 test_pruned_transducer_stateless3_2022_04_29 diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py index 3c0f74005..1eba6093b 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py @@ -32,7 +32,7 @@ This script exports a CTC model from PyTorch to ONNX. --joiner-dim 512 \ --causal True \ --chunk-size 16 \ - --left-context-frames 64 \ + --left-context-frames 128 \ --use-ctc 1 The --chunk-size in training is "16,32,64,-1", so we select one of them @@ -41,7 +41,7 @@ whose value is "64,128,256,-1". It will generate the following file inside $repo/exp: - - ctc-epoch-99-avg-1-chunk-16-left-64.onnx + - ctc-epoch-99-avg-1-chunk-16-left-128.onnx See ./onnx_pretrained-streaming-ctc.py for how to use the exported ONNX models. """ diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 6320e51ca..5d0c9ea43 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -48,7 +48,7 @@ popd --joiner-dim 512 \ --causal True \ --chunk-size 16 \ - --left-context-frames 64 + --left-context-frames 128 The --chunk-size in training is "16,32,64,-1", so we select one of them (excluding -1) during streaming export. The same applies to `--left-context`, @@ -56,9 +56,9 @@ whose value is "64,128,256,-1". It will generate the following 3 files inside $repo/exp: - - encoder-epoch-99-avg-1-chunk-16-left-64.onnx - - decoder-epoch-99-avg-1-chunk-16-left-64.onnx - - joiner-epoch-99-avg-1-chunk-16-left-64.onnx + - encoder-epoch-99-avg-1-chunk-16-left-128.onnx + - decoder-epoch-99-avg-1-chunk-16-left-128.onnx + - joiner-epoch-99-avg-1-chunk-16-left-128.onnx See ./onnx_pretrained-streaming.py for how to use the exported ONNX models. 
""" diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py new file mode 100755 index 000000000..a8b08de34 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming-ctc.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp-ctc-rnnt-small/*.pt" +git lfs pull --include "data/lang_bpe_500/words.txt" +git lfs pull --include "data/lang_bpe_500/HLG.fst" +popd + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 3 \ + --exp-dir $repo/exp-ctc-rnnt-small \ + --causal 1 \ + --use-ctc 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 + +It will generate the following 2 files inside $repo/exp-ctc-rnnt-small: + + - ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx + - ctc-epoch-30-avg-3-chunk-16-left-128.onnx + +You can use either the ``int8.onnx`` model or just the ``.onnx`` model. + +3. Run this file with the exported ONNX models + +python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \ + --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ + --words $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.fst \ + $repo/test_wavs/0.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. + +Note: HLG.fst is generated directly from ../local/prepare_lang_fst.py +""" + +import argparse +import logging +from typing import Dict, List, Tuple + +import k2 +import kaldifst +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HLG", + type=str, + required=True, + help="""Path to HLG.fst.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(model_filename) + + def init_model(self, model_filename: str): + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + self.init_states() + + def init_states(self, batch_size: int = 1): + meta = self.model.get_modelmeta().custom_metadata_map + logging.info(f"meta={meta}") + + model_type = meta["model_type"] + assert model_type == "zipformer2", model_type + + decode_chunk_len = int(meta["decode_chunk_len"]) + T = int(meta["T"]) + + num_encoder_layers = meta["num_encoder_layers"] + encoder_dims = meta["encoder_dims"] + cnn_module_kernels = meta["cnn_module_kernels"] + left_context_len = meta["left_context_len"] + query_head_dims = meta["query_head_dims"] + value_head_dims = meta["value_head_dims"] + num_heads = meta["num_heads"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def _build_model_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + model_input = {"x": x.numpy()} + model_output = ["log_probs"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, 
key_dim) + name = f"cached_key_{i}" + model_input[name] = tensors[0] + model_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + model_input[name] = tensors[1] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + model_input[name] = tensors[2] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + model_input[name] = tensors[3] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + model_input[name] = tensors[4] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + model_input[name] = tensors[5] + model_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + model_input[name] = embed_states + model_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + model_input[name] = processed_lens + model_output.append(f"new_{name}") + + return model_input, model_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor containing log_probs. Its shape is (N, T, vocab_size) + where T' is usually equal to ((T-7)//2 - 3)//2 + """ + model_input, model_output_names = self._build_model_input_output(x) + + out = self.model.run(model_output_names, model_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + logging.info(f"Resample {sample_rate} to {expected_sample_rate}") + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
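+
+    Note:
+      The fixed options used here are 16 kHz input, 80-dim fbank,
+      dither disabled and ``snip_edges=False``; see the option values
+      set in the function body.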
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + word_table = k2.SymbolTable.from_file(args.words) + model = OnnxModel(model_filename=args.nn_model) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.init_decoding() + + chunk = int(1 * sample_rate) # 1 second + start = 0 + + n = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + + # simulate streaming + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + + log_probs = model(frames) + log_probs = log_probs.squeeze(0).cpu().numpy() + + decodable = DecodableCtc(log_probs, offset=n) + n += log_probs.shape[0] + + num_processed_frames += offset + decoder.advance_decoding(decodable) + + if not decoder.reached_final(): + logging.info(f"Failed to decode {args.sound_file}") + return + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + + if not ok: + logging.info(f"Failed to get linear symbol sequence for {args.sound_file}") + return + + hyps = " ".join([word_table[i] for i in osymbols_out]).lower() + logging.info(f"\n{args.sound_file}\n{hyps}") + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() From 413220d6a449a26889e72c6110199568260a2cd4 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 18 Mar 2024 20:25:57 +0800 Subject: [PATCH 135/216] Minor fixes for the `multi_zh_en` recipe (#1526) --- egs/multi_zh_en/ASR/prepare.sh | 3 - .../ASR/zipformer/decode_stream.py | 1 + .../ASR/zipformer/streaming_decode.py | 870 +++++++++++++++++- icefall/utils.py | 8 +- 4 files changed, 875 insertions(+), 7 deletions(-) create mode 120000 egs/multi_zh_en/ASR/zipformer/decode_stream.py mode change 120000 => 100755 egs/multi_zh_en/ASR/zipformer/streaming_decode.py diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh index 9f2be5a5c..a1530be29 100755 --- a/egs/multi_zh_en/ASR/prepare.sh +++ b/egs/multi_zh_en/ASR/prepare.sh @@ -115,9 +115,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then cat ./data/lang_bpe_500/transcript_words.txt \ >> 
$lang_dir/text_words_segmentation - - cat ./data/lang_char/text \ - >> $lang_dir/text fi cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \ diff --git a/egs/multi_zh_en/ASR/zipformer/decode_stream.py b/egs/multi_zh_en/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py deleted file mode 120000 index 13fd02a78..000000000 --- a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..7b9bd2d6c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py @@ -0,0 +1,869 @@ +#!/usr/bin/env python3 +# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import AsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. 
+ You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. 
+ + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
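+
+    A rough round-trip sketch (illustrative only; ``state_list`` is any list
+    of per-utterance states):
+
+        batch_states = stack_states(state_list)
+        recovered = unstack_states(batch_states)
+        assert len(recovered) == len(state_list)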
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
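+
+    A typical call, as issued from :func:`decode_one_chunk` below (the shape
+    comments are indicative only):
+
+        encoder_out, encoder_out_lens, new_states = streaming_forward(
+            features=features,          # (N, num_frames, 80) fbank chunk
+            feature_lens=feature_lens,  # (N,)
+            model=model,
+            states=stack_states(states),
+            chunk_size=chunk_size,
+            left_context_len=left_context_len,
+        )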
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    assert params.causal, params.causal
+    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+    assert (
+        "," not in params.left_context_frames
+    ), "left_context_frames should be one value in decoding."
+    params.suffix += f"-chunk-{params.chunk_size}"
+    params.suffix += f"-left-context-{params.left_context_frames}"
+
+    # for fast_beam_search
+    if params.decoding_method == "fast_beam_search":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if start >= 0:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = 
find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + multi_dataset = MultiDataset(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_cuts = [test_sets_cuts[k] for k in test_sets] + for test_set, test_cut in zip(test_sets, test_cuts): + logging.info(f"Decoding {test_set}") + test_cut = test_cut.filter(remove_short_utt) + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/icefall/utils.py b/icefall/utils.py index 31f9801d9..2cb2edf93 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1081,9 +1081,11 @@ def write_surt_error_stats( f"{cut_id}:\t" + " ".join( ( - ref_word - if ref_word == hyp_word - else f"({ref_word}->{hyp_word})" + ( + ref_word + if ref_word == hyp_word + else f"({ref_word}->{hyp_word})" + ) for ref_word, hyp_word in ali ) ), From 9bd30853ae80b592c20dbaceeb46299b5b8d0f34 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 20 Mar 2024 15:35:14 +0800 Subject: [PATCH 136/216] Update diagnostics.py (#1562) --- icefall/diagnostics.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index a3c480c9c..37872f233 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -1,6 +1,7 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey +# Copyright 2022-2024 Xiaomi Corp. 
(authors: Daniel Povey # Zengwei Yao -# Mingshuang Luo) +# Mingshuang Luo, +# Zengrui Jin,) # # See ../LICENSE for clarification regarding multiple authors # @@ -16,9 +17,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import random from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch from torch import Tensor, nn @@ -653,7 +655,13 @@ def attach_diagnostics( _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) _model_diagnostic[f"{_name}.param_grad"].accumulate(grad) - parameter.register_hook(param_backward_hook) + try: + parameter.register_hook(param_backward_hook) + except: + logging.warning( + f"Warning: could not register backward hook for parameter {name}, " + f"it might not be differentiable." + ) return ans From d5cd78a63722f2c105d0468f0a51f3ac8f404ad8 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 20 Mar 2024 16:43:45 +0800 Subject: [PATCH 137/216] Update hooks.py (#1564) --- icefall/hooks.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/icefall/hooks.py b/icefall/hooks.py index 398a5f689..1c5bd2ae6 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -1,4 +1,6 @@ -# Copyright 2021-2022 Xiaomi Corporation (authors: Zengwei Yao, Daniel Povey) +# Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao, +# Daniel Povey, +# Zengrui Jin,) # # See ../../LICENSE for clarification regarding multiple authors # @@ -77,7 +79,13 @@ def register_inf_check_hooks(model: nn.Module) -> None: if not torch.isfinite(grad.to(torch.float32).sum()): logging.warning(f"The sum of {_name}.param_grad is not finite") - parameter.register_hook(param_backward_hook) + try: + parameter.register_hook(param_backward_hook) + except: + logging.warning( + f"Warning: could not register backward hook for parameter {name}, " + f"it might not be differentiable." + ) def _test_inf_check_hooks(): From 387833fb7ced8f28b53c55ae5cc14becafa9c33c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 21 Mar 2024 12:05:30 +0800 Subject: [PATCH 138/216] Doc: Add huggingface mirror for users from China. (#1565) --- docs/source/for-dummies/environment-setup.rst | 4 ++++ docs/source/installation/index.rst | 3 +++ 2 files changed, 7 insertions(+) diff --git a/docs/source/for-dummies/environment-setup.rst b/docs/source/for-dummies/environment-setup.rst index a68e9d3ed..e257b915c 100644 --- a/docs/source/for-dummies/environment-setup.rst +++ b/docs/source/for-dummies/environment-setup.rst @@ -74,6 +74,10 @@ to install dependencies of `icefall`_: pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html + # For users from China + # 中国国内用户,如果访问不了 huggingface, 请使用 + # pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu-cn.html + # Install the latest version of lhotse pip install git+https://github.com/lhotse-speech/lhotse diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index 5a034ef5b..87318f30e 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -206,6 +206,9 @@ We will install `k2`_ from pre-compiled wheels by following .. 
code-block:: bash (test-icefall) kuangfangjun:~$ pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda.html + # For users from China + # 中国国内用户,如果访问不了 huggingface, 请使用 + # pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda-cn.html Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Looking in links: https://k2-fsa.github.io/k2/cuda.html From bddc3fca7ad104629df33cbf0df69fb0a926afd3 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:08:58 +0800 Subject: [PATCH 139/216] Fix adapter in streaming_forward (#1560) --- .../ASR/zipformer_adapter/export-onnx.py | 6 +- .../ASR/zipformer_adapter/export.py | 520 ++++++++++++++++++ .../ASR/zipformer_adapter/zipformer.py | 12 + 3 files changed, 536 insertions(+), 2 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer_adapter/export.py diff --git a/egs/librispeech/ASR/zipformer_adapter/export-onnx.py b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py index ea29e8159..062396168 100755 --- a/egs/librispeech/ASR/zipformer_adapter/export-onnx.py +++ b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py @@ -27,11 +27,13 @@ popd 2. Export the model to ONNX -./zipformer/export-onnx.py \ +./zipformer_adapter/export-onnx.py \ --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ + --use-adapters 1 \ + --adapter-dim 32 \ --exp-dir $repo/exp \ --num-encoder-layers "2,2,3,4,3,2" \ --downsampling-factor "1,2,4,8,4,2" \ @@ -131,7 +133,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="zipformer/exp", + default="zipformer_adapter/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, diff --git a/egs/librispeech/ASR/zipformer_adapter/export.py b/egs/librispeech/ASR/zipformer_adapter/export.py new file mode 100755 index 000000000..72dfc081b --- /dev/null +++ b/egs/librispeech/ASR/zipformer_adapter/export.py @@ -0,0 +1,520 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer_adapter/export.py \ + --exp-dir ./zipformer_adapter/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --use-adapters 1 \ + --adapter-dim 16 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. 
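+
+A minimal loading sketch (the path below is only an example)::
+
+    import torch
+    model = torch.jit.load("zipformer_adapter/exp/jit_script.pt")
+    model.eval()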
+ +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer_adapter/export.py \ + --exp-dir ./zipformer_adapter/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --use-adapters 1 \ + --adapter-dim 16 \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer_adapter/export.py \ + --exp-dir ./zipformer_adapter/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --use-adapters 1 \ + --adapter-dim 16 \ + --avg 9 + +- For streaming model: + +./zipformer_adapter/export.py \ + --exp-dir ./zipformer_adapter/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --use-adapters 1 \ + --adapter-dim 16 \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer_adapter/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer_adapter/decode_gigaspeech.py \ + --exp-dir ./zipformer_adapter/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --use-adapters 1 \ + --adapter-dim 16 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zipformer_adapter/decode.py` and `zipformer_adapter/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer_adapter/decode_gigaspeech.py \ + --exp-dir ./zipformer_adapter/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zipformer_adapter/streaming_decode.py \ + --exp-dir ./zipformer_adapter/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_finetune_arguments, add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_adapter/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. 
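+
+        A rough chunk-by-chunk loop (illustrative; ``encoder_model`` and
+        ``chunks`` are placeholders for the wrapped encoder and your own
+        chunked feature batches):
+
+            states = encoder_model.get_init_states(batch_size, device)
+            for features, feature_lengths in chunks:
+                out, out_lens, states = encoder_model(
+                    features, feature_lengths, states
+                )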
+ """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
+ """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index 4e4695fa5..8e2dfdd72 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -1004,6 +1004,9 @@ class Zipformer2EncoderLayer(nn.Module): ) src = src + self_attn + if self.use_adapters and self.post_sa_adapter is not None: + src = self.post_sa_adapter(src) + src_conv, cached_conv1 = self.conv_module1.streaming_forward( src, cache=cached_conv1, @@ -1016,6 +1019,9 @@ class Zipformer2EncoderLayer(nn.Module): # bypass in the middle of the layer. src = self.bypass_mid(src_orig, src) + if self.use_adapters and self.mid_adapter is not None: + src = self.mid_adapter(src) + self_attn, cached_val2 = self.self_attn2.streaming_forward( src, attn_weights=attn_weights, @@ -1031,12 +1037,18 @@ class Zipformer2EncoderLayer(nn.Module): ) src = src + src_conv + if self.use_adapters and self.post_conv_adapter is not None: + src = self.post_conv_adapter(src) + src = src + self.feed_forward3(src) src = self.norm(src) src = self.bypass(src_orig, src) + if self.use_adapters and self.adapter is not None: + src = self.adapter(src) + return ( src, cached_key, From 353469182c732d63efbe02c93e8f2408ea11c2e1 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Thu, 21 Mar 2024 15:59:43 +0800 Subject: [PATCH 140/216] fix issue in zipformer.py (#1566) --- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 61ae378d8..17a3f8719 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -788,7 +788,7 @@ class Zipformer2EncoderLayer(nn.Module): selected_attn_weights = attn_weights[0:1] if torch.jit.is_scripting() or torch.jit.is_tracing(): pass - elif not self.training and random.random() < float(self.const_attention_rate): + elif self.training and random.random() < float(self.const_attention_rate): # Make attention weights constant. The intention is to # encourage these modules to do something similar to an # averaging-over-time operation. 
From bb9ebcfb0664762534d53db0f3fdf2aed7a52e18 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 23 Mar 2024 09:27:28 +0800 Subject: [PATCH 141/216] Fix CI (#1563) --- .github/workflows/ljspeech.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml index 25402275b..e202d21b5 100644 --- a/.github/workflows/ljspeech.yml +++ b/.github/workflows/ljspeech.yml @@ -90,7 +90,7 @@ jobs: path: ./*.wav - name: Release exported onnx models - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' uses: svenstaro/upload-release-action@v2 with: file_glob: true From b156b6c291bf8921053a7fe72adb212a8b4f42e1 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 26 Mar 2024 09:42:46 +0800 Subject: [PATCH 142/216] Add use-mux to finetune commands (#1567) --- egs/gigaspeech/KWS/run.sh | 1 + egs/wenetspeech/KWS/run.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh index ea04c7c9b..42e864efe 100755 --- a/egs/gigaspeech/KWS/run.sh +++ b/egs/gigaspeech/KWS/run.sh @@ -123,6 +123,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then --exp-dir zipformer/exp_finetune \ --bpe-model data/lang_bpe_500/bpe.model \ --use-fp16 1 \ + --use-mux 1 \ --decoder-dim 320 \ --joiner-dim 320 \ --num-encoder-layers 1,1,1,1,1,1 \ diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 2bdd6a5f3..f702e0817 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -126,6 +126,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then --lang-dir ./data/lang_partial_tone \ --pinyin-type partial_with_tone \ --use-fp16 1 \ + --use-mux 1 \ --decoder-dim 320 \ --joiner-dim 320 \ --num-encoder-layers 1,1,1,1,1,1 \ From 42de4591107868816d240c991b1e8acdfdfd090b Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 26 Mar 2024 10:38:21 +0800 Subject: [PATCH 143/216] Fix decoding finetune model (#1568) --- egs/gigaspeech/KWS/run.sh | 7 +++++-- egs/wenetspeech/KWS/run.sh | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh index 42e864efe..bd562ce1c 100755 --- a/egs/gigaspeech/KWS/run.sh +++ b/egs/gigaspeech/KWS/run.sh @@ -47,7 +47,9 @@ fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Decode the model." - for t in small, large; do + + export CUDA_VISIBLE_DEVICES="0" + for t in small large; do python ./zipformer/decode.py \ --epoch 12 \ --avg 2 \ @@ -140,7 +142,8 @@ fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 1: Decode the finetuned model." - for t in small, large; do + export CUDA_VISIBLE_DEVICES="0" + for t in small large; do python ./zipformer/decode.py \ --epoch 10 \ --avg 2 \ diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index f702e0817..8698e9fcc 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -48,7 +48,8 @@ fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Decode the model." - for t in small, large; do + export CUDA_VISIBLE_DEVICES="0" + for t in small large; do python ./zipformer/decode.py \ --epoch 18 \ --avg 2 \ @@ -143,7 +144,8 @@ fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 1: Decode the finetuned model." 
- for t in small, large; do + export CUDA_VISIBLE_DEVICES="0" + for t in small large; do python ./zipformer/decode.py \ --epoch 10 \ --avg 2 \ From 6cbddaa8e32ec5bc5c2fcc60a6d2409c7f5c7b11 Mon Sep 17 00:00:00 2001 From: Dadoou <33223302+Dadoou@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:47:38 +0800 Subject: [PATCH 144/216] Add base choice to model_name argument for whisper model. (#1573) Co-authored-by: dadoou --- egs/aishell/ASR/whisper/decode.py | 2 +- egs/aishell/ASR/whisper/train.py | 2 +- egs/multi_zh-hans/ASR/whisper/decode.py | 2 +- egs/multi_zh-hans/ASR/whisper/train.py | 2 +- egs/speechio/ASR/whisper/decode.py | 2 +- egs/wenetspeech/ASR/whisper/decode.py | 2 +- egs/wenetspeech/ASR/whisper/train.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py index c632d0757..5350cb2b0 100755 --- a/egs/aishell/ASR/whisper/decode.py +++ b/egs/aishell/ASR/whisper/decode.py @@ -214,7 +214,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "tiny"], + choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], help="""The model name to use. """, ) diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index 6ccb8d363..d77f8c270 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -147,7 +147,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "tiny"], + choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], help="""The model name to use. """, ) diff --git a/egs/multi_zh-hans/ASR/whisper/decode.py b/egs/multi_zh-hans/ASR/whisper/decode.py index aabb80eaf..2a9c2e75d 100644 --- a/egs/multi_zh-hans/ASR/whisper/decode.py +++ b/egs/multi_zh-hans/ASR/whisper/decode.py @@ -214,7 +214,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "tiny"], + choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], help="""The model name to use. """, ) diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py index b1b60077c..7a0781d5a 100644 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -146,7 +146,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "tiny"], + choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], help="""The model name to use. """, ) diff --git a/egs/speechio/ASR/whisper/decode.py b/egs/speechio/ASR/whisper/decode.py index 001367791..70f743eee 100644 --- a/egs/speechio/ASR/whisper/decode.py +++ b/egs/speechio/ASR/whisper/decode.py @@ -215,7 +215,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "tiny"], + choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], help="""The model name to use. """, ) diff --git a/egs/wenetspeech/ASR/whisper/decode.py b/egs/wenetspeech/ASR/whisper/decode.py index 103f8d725..34b1c80ef 100755 --- a/egs/wenetspeech/ASR/whisper/decode.py +++ b/egs/wenetspeech/ASR/whisper/decode.py @@ -213,7 +213,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "tiny"], + choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], help="""The model name to use. 
""", ) diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py index 6ff500ab9..493f2728a 100644 --- a/egs/wenetspeech/ASR/whisper/train.py +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -145,7 +145,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "tiny"], + choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], help="""The model name to use. """, ) From 9369c2bef96f30160bfbd38a99ac4d0fc56f2c4f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 2 Apr 2024 16:08:09 +0800 Subject: [PATCH 145/216] Add comments to prepare.sh in aidatatang (#1575) --- egs/aidatatang_200zh/ASR/prepare.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 40ee2eb97..09dfd5fac 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -16,8 +16,8 @@ perturb_speed=true # # - $dl_dir/aidatatang_200zh # You can find "corpus" and "transcript" inside it. -# You can download it at -# https://openslr.org/62/ +# You can download it at https://openslr.org/62/ +# If you download the data by yourself, DON'T FORGET to extract the *.tar.gz files under corpus. dl_dir=$PWD/download From c45e9fecfb89bada0233a7b6cd9626fb6633a696 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 3 Apr 2024 11:26:24 +0800 Subject: [PATCH 146/216] support torch 2.2.2 in docker images (#1578) --- .../scripts/docker/generate_build_matrix.py | 20 ++++-- .github/workflows/build-docker-image.yml | 2 +- .github/workflows/run-docker-image.yml | 2 +- docker/torch2.0.0-cuda11.7.dockerfile | 1 + docker/torch2.1.0-cuda11.8.dockerfile | 1 + docker/torch2.1.0-cuda12.1.dockerfile | 1 + docker/torch2.2.0-cuda11.8.dockerfile | 1 + docker/torch2.2.0-cuda12.1.dockerfile | 1 + docker/torch2.2.1-cuda11.8.dockerfile | 1 + docker/torch2.2.1-cuda12.1.dockerfile | 1 + docker/torch2.2.2-cuda11.8.dockerfile | 71 +++++++++++++++++++ docker/torch2.2.2-cuda12.1.dockerfile | 71 +++++++++++++++++++ docs/source/docker/intro.rst | 2 + 13 files changed, 168 insertions(+), 7 deletions(-) create mode 100644 docker/torch2.2.2-cuda11.8.dockerfile create mode 100644 docker/torch2.2.2-cuda12.1.dockerfile diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 675e37c37..77dccb93e 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,10 +45,13 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20240223" kaldifeat_version = "1.25.4.dev20240223" - version = "20240318" + version = "20240401" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] - torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - torch_version += ["2.2.0", "2.2.1"] + torch_version = [] + torch_version += ["1.13.0", "1.13.1"] + torch_version += ["2.0.0", "2.0.1"] + torch_version += ["2.1.0", "2.1.1", "2.1.2"] + torch_version += ["2.2.0", "2.2.1", "2.2.2"] matrix = [] for p in python_version: @@ -62,10 +65,17 @@ def get_matrix(): if version_gt(p, "3.11") and not version_gt(t, "2.1"): continue + k2_version_2 = k2_version + kaldifeat_version_2 = kaldifeat_version + + if t == "2.2.2": + k2_version_2 = "1.24.4.dev20240328" + kaldifeat_version_2 = "1.25.4.dev20240329" + matrix.append( { - "k2-version": k2_version, - "kaldifeat-version": kaldifeat_version, + "k2-version": 
k2_version_2, + "kaldifeat-version": kaldifeat_version_2, "version": version, "python-version": p, "torch-version": t, diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index f5796d114..9198cdb7f 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index eab31cccc..a26e704c5 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v2 diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index ad23f8be7..31ff09ac6 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -1,4 +1,5 @@ FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel +# python 3.10 ENV LC_ALL C.UTF-8 diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index 4e6812b83..83b64a8d2 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -1,4 +1,5 @@ FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel +# python 3.10 ENV LC_ALL C.UTF-8 diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index c7de4cf28..ec366a898 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -1,4 +1,5 @@ FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel +# python 3.10 ENV LC_ALL C.UTF-8 diff --git a/docker/torch2.2.0-cuda11.8.dockerfile b/docker/torch2.2.0-cuda11.8.dockerfile index 0104ae870..143f0e066 100644 --- a/docker/torch2.2.0-cuda11.8.dockerfile +++ b/docker/torch2.2.0-cuda11.8.dockerfile @@ -1,4 +1,5 @@ FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel +# python 3.10 ENV LC_ALL C.UTF-8 diff --git a/docker/torch2.2.0-cuda12.1.dockerfile b/docker/torch2.2.0-cuda12.1.dockerfile index ccd5265b2..c6d5a771f 100644 --- a/docker/torch2.2.0-cuda12.1.dockerfile +++ b/docker/torch2.2.0-cuda12.1.dockerfile @@ -1,4 +1,5 @@ FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel +# python 3.10 ENV LC_ALL C.UTF-8 diff --git a/docker/torch2.2.1-cuda11.8.dockerfile b/docker/torch2.2.1-cuda11.8.dockerfile index 
0528ba72f..d874134d7 100644 --- a/docker/torch2.2.1-cuda11.8.dockerfile +++ b/docker/torch2.2.1-cuda11.8.dockerfile @@ -1,4 +1,5 @@ FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel +# python 3.10 ENV LC_ALL C.UTF-8 diff --git a/docker/torch2.2.1-cuda12.1.dockerfile b/docker/torch2.2.1-cuda12.1.dockerfile index 3cdbb16ec..6e4ef290a 100644 --- a/docker/torch2.2.1-cuda12.1.dockerfile +++ b/docker/torch2.2.1-cuda12.1.dockerfile @@ -1,4 +1,5 @@ FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel +# python 3.10 ENV LC_ALL C.UTF-8 diff --git a/docker/torch2.2.2-cuda11.8.dockerfile b/docker/torch2.2.2-cuda11.8.dockerfile new file mode 100644 index 000000000..bca40a065 --- /dev/null +++ b/docker/torch2.2.2-cuda11.8.dockerfile @@ -0,0 +1,71 @@ +FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240328+cuda11.8.torch2.2.2" +ARG KALDIFEAT_VERSION="1.25.4.dev20240329+cuda11.8.torch2.2.2" +ARG TORCHAUDIO_VERSION="2.2.2+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.2.2-cuda12.1.dockerfile b/docker/torch2.2.2-cuda12.1.dockerfile new file mode 100644 index 000000000..4fb8946e7 --- /dev/null +++ b/docker/torch2.2.2-cuda12.1.dockerfile @@ -0,0 +1,71 @@ +FROM pytorch/pytorch:2.2.2-cuda12.1-cudnn8-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240328+cuda12.1.torch2.2.2" +ARG KALDIFEAT_VERSION="1.25.4.dev20240329+cuda12.1.torch2.2.2" +ARG TORCHAUDIO_VERSION="2.2.2+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} 
-f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index 1acaa3d4f..2f4bdb3f6 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -34,6 +34,8 @@ which will give you something like below: .. code-block:: bash + "torch2.2.2-cuda12.1" + "torch2.2.2-cuda11.8" "torch2.2.1-cuda12.1" "torch2.2.1-cuda11.8" "torch2.2.0-cuda12.1" From 87843e93821da00c6e7ffbc13c8da4d8766ab49c Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 4 Apr 2024 23:29:16 +0800 Subject: [PATCH 147/216] k2SSL: a Faster and Better Framework for Self-Supervised Speech Representation Learning (#1500) * Add k2SSL * fix flake8 * fix for black * fix for black * fix for black * Update ssl_datamodule.py * Fix bugs in HubertDataset * update comments * add librilight * add checkpoint convert script * format --------- Co-authored-by: yifanyeung Co-authored-by: zzasdf <15218404468@163.com> --- .../SSL/zipformer/asr_datamodule.py | 1 + egs/librilight/SSL/zipformer/beam_search.py | 1 + egs/librilight/SSL/zipformer/dataset.py | 1 + egs/librilight/SSL/zipformer/decode.py | 1045 +++++++ egs/librilight/SSL/zipformer/decoder.py | 1 + .../SSL/zipformer/encoder_interface.py | 1 + egs/librilight/SSL/zipformer/finetune.py | 1552 +++++++++++ egs/librilight/SSL/zipformer/hubert_ce.py | 1 + egs/librilight/SSL/zipformer/joiner.py | 1 + egs/librilight/SSL/zipformer/model.py | 1 + egs/librilight/SSL/zipformer/optim.py | 1 + egs/librilight/SSL/zipformer/pretrain.py | 1366 +++++++++ egs/librilight/SSL/zipformer/scaling.py | 1 + .../SSL/zipformer/ssl_datamodule.py | 334 +++ egs/librilight/SSL/zipformer/utils.py | 1 + .../SSL/zipformer/wav2vec2_module.py | 1 + egs/librilight/SSL/zipformer/zipformer.py | 1 + egs/librispeech/SSL/hubert/asr_datamodule.py | 287 ++ .../SSL/hubert/attention_module.py | 840 ++++++ egs/librispeech/SSL/hubert/beam_search.py | 1 + egs/librispeech/SSL/hubert/dataset.py | 367 +++ egs/librispeech/SSL/hubert/decode.py | 1045 +++++++ egs/librispeech/SSL/hubert/decode_ce.py | 1045 +++++++ egs/librispeech/SSL/hubert/decoder.py | 1 + egs/librispeech/SSL/hubert/finetune.py | 1254 +++++++++ egs/librispeech/SSL/hubert/finetune_ce.py | 1254 +++++++++ egs/librispeech/SSL/hubert/hubert.py | 984 +++++++ egs/librispeech/SSL/hubert/hubert_ce.py | 940 +++++++ egs/librispeech/SSL/hubert/joiner.py | 1 + egs/librispeech/SSL/hubert/model.py | 344 +++ egs/librispeech/SSL/hubert/optim.py | 1 + egs/librispeech/SSL/hubert/pretrain.py | 1082 ++++++++ egs/librispeech/SSL/hubert/pretrain_ce.py | 1082 ++++++++ egs/librispeech/SSL/hubert/scaling.py | 1 + egs/librispeech/SSL/hubert/ssl_datamodule.py | 341 +++ egs/librispeech/SSL/hubert/utils.py | 338 +++ egs/librispeech/SSL/hubert/wav2vec2_module.py | 593 ++++ .../local/attach_kmeans_to_supervisions.py | 52 + .../local/convert_checkpoint_from_fairseq.py | 18 + 
egs/librispeech/SSL/local/prepare_char.py | 259 ++ egs/librispeech/SSL/local/prepare_lang.py | 388 +++ .../SSL/local/process_librispeech4finetune.py | 107 + .../SSL/local/process_librispeech4pretrain.py | 104 + egs/librispeech/SSL/local/process_raw_cuts.py | 23 + egs/librispeech/SSL/shared | 1 + .../SSL/zipformer/asr_datamodule.py | 1 + egs/librispeech/SSL/zipformer/beam_search.py | 1 + egs/librispeech/SSL/zipformer/dataset.py | 1 + egs/librispeech/SSL/zipformer/decode.py | 1043 +++++++ egs/librispeech/SSL/zipformer/decoder.py | 1 + .../SSL/zipformer/encoder_interface.py | 1 + egs/librispeech/SSL/zipformer/finetune.py | 1551 +++++++++++ egs/librispeech/SSL/zipformer/hubert_ce.py | 601 ++++ egs/librispeech/SSL/zipformer/joiner.py | 1 + egs/librispeech/SSL/zipformer/model.py | 344 +++ egs/librispeech/SSL/zipformer/optim.py | 1 + egs/librispeech/SSL/zipformer/pretrain.py | 1380 ++++++++++ egs/librispeech/SSL/zipformer/scaling.py | 1 + .../SSL/zipformer/ssl_datamodule.py | 1 + egs/librispeech/SSL/zipformer/utils.py | 337 +++ .../SSL/zipformer/wav2vec2_module.py | 108 + egs/librispeech/SSL/zipformer/zipformer.py | 2438 +++++++++++++++++ 62 files changed, 24874 insertions(+) create mode 120000 egs/librilight/SSL/zipformer/asr_datamodule.py create mode 120000 egs/librilight/SSL/zipformer/beam_search.py create mode 120000 egs/librilight/SSL/zipformer/dataset.py create mode 100644 egs/librilight/SSL/zipformer/decode.py create mode 120000 egs/librilight/SSL/zipformer/decoder.py create mode 120000 egs/librilight/SSL/zipformer/encoder_interface.py create mode 100644 egs/librilight/SSL/zipformer/finetune.py create mode 120000 egs/librilight/SSL/zipformer/hubert_ce.py create mode 120000 egs/librilight/SSL/zipformer/joiner.py create mode 120000 egs/librilight/SSL/zipformer/model.py create mode 120000 egs/librilight/SSL/zipformer/optim.py create mode 100644 egs/librilight/SSL/zipformer/pretrain.py create mode 120000 egs/librilight/SSL/zipformer/scaling.py create mode 100644 egs/librilight/SSL/zipformer/ssl_datamodule.py create mode 120000 egs/librilight/SSL/zipformer/utils.py create mode 120000 egs/librilight/SSL/zipformer/wav2vec2_module.py create mode 120000 egs/librilight/SSL/zipformer/zipformer.py create mode 100644 egs/librispeech/SSL/hubert/asr_datamodule.py create mode 100644 egs/librispeech/SSL/hubert/attention_module.py create mode 120000 egs/librispeech/SSL/hubert/beam_search.py create mode 100644 egs/librispeech/SSL/hubert/dataset.py create mode 100644 egs/librispeech/SSL/hubert/decode.py create mode 100644 egs/librispeech/SSL/hubert/decode_ce.py create mode 120000 egs/librispeech/SSL/hubert/decoder.py create mode 100644 egs/librispeech/SSL/hubert/finetune.py create mode 100644 egs/librispeech/SSL/hubert/finetune_ce.py create mode 100644 egs/librispeech/SSL/hubert/hubert.py create mode 100644 egs/librispeech/SSL/hubert/hubert_ce.py create mode 120000 egs/librispeech/SSL/hubert/joiner.py create mode 100644 egs/librispeech/SSL/hubert/model.py create mode 120000 egs/librispeech/SSL/hubert/optim.py create mode 100644 egs/librispeech/SSL/hubert/pretrain.py create mode 100644 egs/librispeech/SSL/hubert/pretrain_ce.py create mode 120000 egs/librispeech/SSL/hubert/scaling.py create mode 100644 egs/librispeech/SSL/hubert/ssl_datamodule.py create mode 100644 egs/librispeech/SSL/hubert/utils.py create mode 100644 egs/librispeech/SSL/hubert/wav2vec2_module.py create mode 100644 egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py create mode 100644 
egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py create mode 100644 egs/librispeech/SSL/local/prepare_char.py create mode 100644 egs/librispeech/SSL/local/prepare_lang.py create mode 100644 egs/librispeech/SSL/local/process_librispeech4finetune.py create mode 100644 egs/librispeech/SSL/local/process_librispeech4pretrain.py create mode 100644 egs/librispeech/SSL/local/process_raw_cuts.py create mode 120000 egs/librispeech/SSL/shared create mode 120000 egs/librispeech/SSL/zipformer/asr_datamodule.py create mode 120000 egs/librispeech/SSL/zipformer/beam_search.py create mode 120000 egs/librispeech/SSL/zipformer/dataset.py create mode 100644 egs/librispeech/SSL/zipformer/decode.py create mode 120000 egs/librispeech/SSL/zipformer/decoder.py create mode 120000 egs/librispeech/SSL/zipformer/encoder_interface.py create mode 100644 egs/librispeech/SSL/zipformer/finetune.py create mode 100644 egs/librispeech/SSL/zipformer/hubert_ce.py create mode 120000 egs/librispeech/SSL/zipformer/joiner.py create mode 100644 egs/librispeech/SSL/zipformer/model.py create mode 120000 egs/librispeech/SSL/zipformer/optim.py create mode 100644 egs/librispeech/SSL/zipformer/pretrain.py create mode 120000 egs/librispeech/SSL/zipformer/scaling.py create mode 120000 egs/librispeech/SSL/zipformer/ssl_datamodule.py create mode 100644 egs/librispeech/SSL/zipformer/utils.py create mode 100644 egs/librispeech/SSL/zipformer/wav2vec2_module.py create mode 100644 egs/librispeech/SSL/zipformer/zipformer.py diff --git a/egs/librilight/SSL/zipformer/asr_datamodule.py b/egs/librilight/SSL/zipformer/asr_datamodule.py new file mode 120000 index 000000000..b9313bffc --- /dev/null +++ b/egs/librilight/SSL/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/asr_datamodule.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/beam_search.py b/egs/librilight/SSL/zipformer/beam_search.py new file mode 120000 index 000000000..3b02c21db --- /dev/null +++ b/egs/librilight/SSL/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/dataset.py b/egs/librilight/SSL/zipformer/dataset.py new file mode 120000 index 000000000..5cd60d3b4 --- /dev/null +++ b/egs/librilight/SSL/zipformer/dataset.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/dataset.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/decode.py b/egs/librilight/SSL/zipformer/decode.py new file mode 100644 index 000000000..95643c5e1 --- /dev/null +++ b/egs/librilight/SSL/zipformer/decode.py @@ -0,0 +1,1045 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
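+# This SSL decode script feeds raw audio and a padding mask directly to
+# model.forward_encoder() (see decode_one_batch() below) rather than precomputed
+# fbank features; feature extraction happens inside the fine-tuned HuBERT/Zipformer
+# encoder itself.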
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from finetune import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
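+        For example, with greedy search and a batch of two utterances, the
+        returned value could look like (hypothetical transcripts):
+            {"greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}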
+ """ + device = next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + + encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(batch["supervisions"]["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["cuts"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" 
in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + 
f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + dev_clean_dl = librispeech.test_dataloaders( + dev_clean_cuts, + do_normalize=params.do_normalize, + ) + dev_other_dl = librispeech.test_dataloaders( + dev_other_cuts, + do_normalize=params.do_normalize, + ) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders( + test_clean_cuts, + do_normalize=params.do_normalize, + ) + test_other_dl = librispeech.test_dataloaders( + test_other_cuts, + do_normalize=params.do_normalize, + ) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + # test_sets = ["dev-clean", "dev-other"] + # test_dl = [dev_clean_dl, dev_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librilight/SSL/zipformer/decoder.py b/egs/librilight/SSL/zipformer/decoder.py new file mode 120000 index 000000000..96dbfc5cd --- /dev/null +++ b/egs/librilight/SSL/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/decoder.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/encoder_interface.py b/egs/librilight/SSL/zipformer/encoder_interface.py new file mode 120000 index 000000000..30859c51b --- /dev/null +++ b/egs/librilight/SSL/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/finetune.py b/egs/librilight/SSL/zipformer/finetune.py new file mode 100644 index 000000000..50dbd5f2d --- /dev/null +++ b/egs/librilight/SSL/zipformer/finetune.py @@ -0,0 +1,1552 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For HuBERT model finetuning: +./hubert/finetune.py \ + --world-size 8 \ + --num-epochs 200 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir hubert/exp \ + --full-libri 0 \ + --max-duration 1000 + +It supports finetuning with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from hubert_ce import HubertModel +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * params.accum_grad + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + # hubert parameters + parser.add_argument( + "--label-rate", + type=float, + default=50, + ) + + parser.add_argument( + "--sample-rate", + type=float, + default=16000, + ) + + parser.add_argument( + "--extractor-mode", + type=str, + default="default", + help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. 
default has a single group + norm with d groups in the first conv block, whereas layer_norm + has layer norms in every block (meant to use with normalize=True)""", + ) + + parser.add_argument( + "--conv-feature-layers", + type=str, + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--conv-bias", type=bool, default=False, help="include bias in conv encoder" + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + default=1.0, + help="multiply feature extractor var grads by this", + ) + + # masking + parser.add_argument("--mask-length", type=int, default=10, help="mask_length") + + parser.add_argument( + "--mask-prob", + type=float, + default=0.65, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length", + ) + + parser.add_argument( + "--mask-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--no-mask-overlap", + type=bool, + default=False, + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-min-space", + type=int, + default=1, + help="min space between spans (if no overlap is enabled)", + ) + + # channel masking + parser.add_argument( + "--mask-channel-length", + type=int, + default=10, + help="length of the mask for features (channels)", + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + default=0.0, + help="probability of replacing a feature with 0", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length for channel masking", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + type=bool, + default=False, + help="whether to allow channel masks to overlap", + ) + + parser.add_argument( + "--mask-channel-min-space", + type=int, + default=1, + help="min space between spans (if no overlap is enabled)", + ) + + # loss computation + parser.add_argument( + "--skip-masked", + type=bool, + default=False, + help="skip computing losses over masked frames", + ) + + parser.add_argument( + "--skip-nomask", + type=bool, + default=False, + help="skip computing losses over unmasked frames", + ) + + parser.add_argument( + "--checkpoint-activations", + type=bool, + default=False, + help="recompute activations and save memory for extra compute", + ) + + parser.add_argument( + "--pred-masked-weight", + type=float, + default=1, + help="weight for masked part in ssl loss", + ) + + parser.add_argument( + "--pred-nomask-weight", + type=float, + default=0, + help="weight for masked part in ssl loss", + ) + + parser.add_argument( + "--loss-weights", + type=float, + nargs="*", + default=[10], + help="weight for masked part in ssl loss", + ) + + # FP16 optimization + parser.add_argument( + "--required-seq-len-multiple", + type=int, + default=2, + help="pad the input to encoder such that the sequence length is divisible by multiple", + ) + + parser.add_argument( + 
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA" + ) + + parser.add_argument( + "--pos-enc-type", + type=str, + default="abs", + help="Positional encoding type to use in conformer", + ) + + parser.add_argument( + "--logit-temp", type=float, default=0.1, help="temperature to divide logits by" + ) + + parser.add_argument( + "--dropout-input", + type=float, + default=0.0, + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + default=0.0, + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--num-classes", + type=int, + nargs="*", + default=[504], + help="""num class, a little larger than the number of cluster, + the largest is for padding, + and the value should be the multiple of 4, for faster computation""", + ) + + parser.add_argument( + "--untie-final-proj", + type=bool, + default=False, + help="use separate projection for each target", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=222, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hubert/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--pretrained-dir", + type=str, + help="""The pretrained model dir. + It specifies the directory where the pretrained checkpoint is saved.""", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.001, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. 
We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--sanity-check", + type=str2bool, + default=False, + help="Check if any of the batches in epoch 1 would cause OOM.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=100000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accum-grad", + type=int, + default=1, + help="""update gradient when batch_idx_train % accum_grad == 0. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + + contains number of updates happen to the model so far across + epochs. + + - sub_batch_idx_train: It contains number of batch trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "sub_batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for pruned RNN-T loss + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if hasattr(params, "pretrained_dir"): + logging.info(f"Loading {params.pretrained_dir}") + pretrained = torch.load(params.pretrained_dir) + encoder = HubertModel(params) + encoder.load_state_dict(pretrained["model"]) + else: + encoder = HubertModel(params) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + 
use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. 
It is an instance of Zipformer in our case. + batch: + A batch of data. See `dataset.HubertAsrDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, num_frames = model( + x=audio, + padding_mask=padding_mask, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = num_frames.sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. 
+ + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for sub_batch_idx, batch in enumerate(train_dl): + params.sub_batch_idx_train += 1 + batch_idx = sub_batch_idx // params.accum_grad + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss / params.accum_grad).backward() + + if sub_batch_idx % params.accum_grad == params.accum_grad - 1: + params.batch_idx_train += 1 + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
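+            # In more detail: every 100 batches the current scale is read back; if it
+            # has collapsed below 8.0 (or below 32.0 on every 400th batch) it is
+            # doubled by passing an explicit value to scaler.update(), which sets the
+            # scale directly instead of applying the scaler's usual growth/backoff
+            # rule. If the scale keeps shrinking, a "bad-model" checkpoint is saved
+            # once it drops below 0.01 and training aborts below 1e-5.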
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if batch_idx % params.accum_grad != params.accum_grad - 1: + optimizer.zero_grad() + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=0) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = ( + librispeech.train_all_shuf_cuts() + if params.full_libri + else librispeech.train_clean_100_cuts() + ) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + do_normalize=params.do_normalize, + sampler_state_dict=sampler_state_dict, + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + + valid_dl = librispeech.valid_dataloaders( + valid_cuts, + do_normalize=params.do_normalize, + ) + + if params.sanity_check and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `dataset.HubertAsrDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + audio = batch["audio"] + logging.info(f"audio shape: {audio.shape}") + + y = sp.encode(batch["supervisions"]["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librilight/SSL/zipformer/hubert_ce.py b/egs/librilight/SSL/zipformer/hubert_ce.py new file mode 120000 index 000000000..2b8482f78 --- /dev/null +++ b/egs/librilight/SSL/zipformer/hubert_ce.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/hubert_ce.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/joiner.py b/egs/librilight/SSL/zipformer/joiner.py new file mode 120000 index 000000000..587823e65 --- /dev/null +++ b/egs/librilight/SSL/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/joiner.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/model.py b/egs/librilight/SSL/zipformer/model.py new file mode 120000 index 000000000..ca3daacca --- /dev/null +++ b/egs/librilight/SSL/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/model.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/optim.py b/egs/librilight/SSL/zipformer/optim.py new file mode 120000 index 000000000..bd2153ebf --- /dev/null +++ b/egs/librilight/SSL/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/optim.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/pretrain.py b/egs/librilight/SSL/zipformer/pretrain.py new file mode 100644 index 000000000..5728dbe75 --- /dev/null +++ b/egs/librilight/SSL/zipformer/pretrain.py @@ -0,0 +1,1366 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
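+#
+# This script pretrains the Zipformer-based HubertModel (hubert_ce.py) on
+# Libri-Light audio with a masked-prediction objective: each batch provides raw
+# audio, a padding mask, and frame-level k-means targets ("kmeans"), and the
+# loss is dominated by predictions at masked frames (--pred-masked-weight
+# defaults to 1, --pred-nomask-weight to 0). The resulting checkpoint can then
+# be loaded, via --pretrained-dir, by the accompanying ASR fine-tuning script,
+# which adds transducer/CTC heads on top of the pretrained encoder.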
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For hubert model pretraining: +./zipformer/pretrain.py \ + --world-size 8 \ + --num-epochs 400 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 87.5 \ + --accum-grad 4 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from hubert_ce import HubertModel +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from ssl_datamodule import LibriLightDataModule +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * params.accum_grad + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + # hubert parameters + parser.add_argument( + "--label-rate", + type=float, + default=50, + ) + + parser.add_argument( + "--sample-rate", + type=float, + default=16000, + ) + + parser.add_argument( + "--extractor-mode", + type=str, + default="default", + help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. 
default has a single group + norm with d groups in the first conv block, whereas layer_norm + has layer norms in every block (meant to use with normalize=True)""", + ) + + parser.add_argument( + "--conv-feature-layers", + type=str, + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--conv-bias", type=bool, default=False, help="include bias in conv encoder" + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + default=1.0, + help="multiply feature extractor var grads by this", + ) + + # masking + parser.add_argument("--mask-length", type=int, default=10, help="mask_length") + + parser.add_argument( + "--mask-prob", + type=float, + default=0.65, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length", + ) + + parser.add_argument( + "--mask-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--no-mask-overlap", + type=bool, + default=False, + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-min-space", + type=int, + default=1, + help="min space between spans (if no overlap is enabled)", + ) + + # channel masking + parser.add_argument( + "--mask-channel-length", + type=int, + default=10, + help="length of the mask for features (channels)", + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + default=0.0, + help="probability of replacing a feature with 0", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length for channel masking", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + type=bool, + default=False, + help="whether to allow channel masks to overlap", + ) + + parser.add_argument( + "--mask-channel-min-space", + type=int, + default=1, + help="min space between spans (if no overlap is enabled)", + ) + + # loss computation + parser.add_argument( + "--skip-masked", + type=bool, + default=False, + help="skip computing losses over masked frames", + ) + + parser.add_argument( + "--skip-nomask", + type=bool, + default=False, + help="skip computing losses over unmasked frames", + ) + + parser.add_argument( + "--checkpoint-activations", + type=bool, + default=False, + help="recompute activations and save memory for extra compute", + ) + + parser.add_argument( + "--pred-masked-weight", + type=float, + default=1, + help="weight for masked part in ssl loss", + ) + + parser.add_argument( + "--pred-nomask-weight", + type=float, + default=0, + help="weight for masked part in ssl loss", + ) + + parser.add_argument( + "--loss-weights", + type=float, + nargs="*", + default=[10], + help="weight for masked part in ssl loss", + ) + + # FP16 optimization + parser.add_argument( + "--required-seq-len-multiple", + type=int, + default=2, + help="pad the input to encoder such that the sequence length is divisible by multiple", + ) + + parser.add_argument( + 
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA" + ) + + parser.add_argument( + "--pos-enc-type", + type=str, + default="abs", + help="Positional encoding type to use in conformer", + ) + + parser.add_argument( + "--logit-temp", type=float, default=0.1, help="temperature to divide logits by" + ) + + parser.add_argument( + "--dropout-input", + type=float, + default=0.0, + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + default=0.0, + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--num-classes", + type=int, + nargs="*", + default=[504], + help="""num class, a little larger than the number of cluster, + the largest is for padding, + and the value should be the multiple of 4, for faster computation""", + ) + + parser.add_argument( + "--untie-final-proj", + type=bool, + default=False, + help="use separate projection for each target", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=400, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. 
+ """, + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=5000, + help="Eden warmup steps", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0, + help="Eden warmup start learning rate", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--sanity-check", + type=str2bool, + default=False, + help="Check if any of the batches in epoch 1 would cause OOM.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=100000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accum-grad", + type=int, + default=4, + help="""update gradient when batch_idx_train % accum_grad == 0. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--max-sample-size", + type=float, + default=250000, + help="max sample size", + ) + + parser.add_argument( + "--min-sample-size", + type=float, + default=32000, + help="min sample size", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. 
+ + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of updates happen to the model so far across + epochs. + + - sub_batch_idx_train: It contains number of batch trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "sub_batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_model(params: AttributeDict) -> nn.Module: + model = HubertModel(params) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. 
+ scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `dataset.HubertDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + kmeans = batch["kmeans"].to(device) + + with torch.set_grad_enabled(is_training): + loss, num_masked_tokens, logging_output = model( + source=audio, target_list=[kmeans], padding_mask=padding_mask + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = num_masked_tokens + for item in logging_output: + info[item] = logging_output[item] + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for sub_batch_idx, batch in enumerate(train_dl): + params.sub_batch_idx_train += 1 + batch_idx = sub_batch_idx // params.accum_grad + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + batch_size = batch["kmeans"].shape[0] + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss / params.accum_grad).backward() + + if sub_batch_idx % params.accum_grad == params.accum_grad - 1: + params.batch_idx_train += 1 + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
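+            # The scale is doubled below via an explicit scaler.update(...) call when
+            # it has collapsed (below 8.0, or below 32.0 on every 400th batch); if it
+            # keeps shrinking, a "bad-model" checkpoint is saved and training aborts
+            # once the scale drops under 1e-5.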
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if batch_idx % params.accum_grad != params.accum_grad - 1: + optimizer.zero_grad() + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + params.warmup_batches, + params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librilight = LibriLightDataModule(args) + + train_cuts = librilight.train_all_shuf_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if ( + c.duration < params.min_sample_size / params.sample_rate + or c.duration > params.max_sample_size / params.sample_rate + ): + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librilight.train_dataloaders( + train_cuts, + sample_rate=params.sample_rate, + label_rate=params.label_rate, + random_crop=params.random_crop, + pad_audio=False, + num_classes=params.num_classes, + do_normalize=params.do_normalize, + sampler_state_dict=sampler_state_dict, + ) + + valid_cuts = librilight.dev_clean_cuts() + # valid_cuts += librilight.dev_other_cuts() + valid_cuts = valid_cuts.filter(remove_short_and_long_utt) + + valid_dl = librilight.valid_dataloaders( + valid_cuts, + sample_rate=params.sample_rate, + label_rate=params.label_rate, + random_crop=params.random_crop, + pad_audio=False, + num_classes=params.num_classes, + do_normalize=params.do_normalize, + ) + + if params.sanity_check and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `dataset.HubertDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + audio = batch["audio"] + logging.info(f"audio shape: {audio.shape}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriLightDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librilight/SSL/zipformer/scaling.py b/egs/librilight/SSL/zipformer/scaling.py new file mode 120000 index 000000000..24b661dfb --- /dev/null +++ b/egs/librilight/SSL/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/scaling.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/ssl_datamodule.py b/egs/librilight/SSL/zipformer/ssl_datamodule.py new file mode 100644 index 000000000..dc0dbec6c --- /dev/null +++ b/egs/librilight/SSL/zipformer/ssl_datamodule.py @@ -0,0 +1,334 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import glob +import logging +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from dataset import HubertDataset +from lhotse import CutSet, combine, load_manifest_lazy +from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriLightDataModule: + """ + DataModule for SSL experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). 
+ + It contains all the common data pipeline modules used in SSL + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + + This class should be derived for specific corpora used in SSL tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR SSL related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/kmeans"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=float, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + group.add_argument( + "--do-normalize", + type=str2bool, + default=True, + help="whether to normalize the data", + ) + group.add_argument( + "--random-crop", + type=str2bool, + default=True, + help="audio sample rate", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sample_rate: float = 16000, + label_rate: float = 50, + random_crop: bool = True, + pad_audio: bool = False, + num_classes: list = [504], + do_normalize: bool = True, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = HubertDataset( + sample_rate=sample_rate, + label_rate=label_rate, + random_crop=random_crop, + pad_audio=pad_audio, + num_classes=num_classes, + do_normalize=do_normalize, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
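+        # Each worker then re-seeds itself with `seed + worker_id` through the
+        # `_SeedWorkers` callable defined above, so different workers produce
+        # different random crops while the pipeline stays reproducible for a
+        # fixed main-process seed.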
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + sample_rate: float = 16000, + label_rate: float = 50, + random_crop: bool = True, + pad_audio: bool = False, + num_classes: list = [504], + do_normalize: bool = True, + ) -> DataLoader: + logging.info("About to create dev dataset") + validate = HubertDataset( + sample_rate=sample_rate, + label_rate=label_rate, + random_crop=random_crop, + pad_audio=pad_audio, + num_classes=num_classes, + do_normalize=do_normalize, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + sample_rate: float = 16000, + label_rate: float = 50, + random_crop: bool = True, + pad_audio: bool = False, + num_classes: list = [504], + do_normalize: bool = True, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = HubertDataset( + sample_rate=sample_rate, + label_rate=label_rate, + random_crop=random_crop, + pad_audio=pad_audio, + num_classes=num_classes, + do_normalize=do_normalize, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def small_cuts(self) -> CutSet: + logging.info("About to get small cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librilight_cuts_small.jsonl.gz" + ) + + @lru_cache() + def medium_cuts(self) -> CutSet: + logging.info("About to get medium cuts") + filenames = glob.glob( + f"{self.args.manifest_dir}/medium_splits/librilight_cuts_medium.*.jsonl.gz" + ) + pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + sorted_filenames = [f[1] for f in idx_filenames] + logging.info( + f"Loading LibriLight medium {len(sorted_filenames)} splits in lazy mode" + ) + + return combine(load_manifest_lazy(p) for p in sorted_filenames) + + @lru_cache() + def large_cuts(self) -> CutSet: + logging.info("About to get large cuts") + filenames = glob.glob( + f"{self.args.manifest_dir}/large_splits/librilight_cuts_large.*.jsonl.gz" + ) + pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + sorted_filenames = [f[1] for f in idx_filenames] + logging.info( + f"Loading LibriLight large {len(sorted_filenames)} splits in lazy mode" + ) + + return combine(load_manifest_lazy(p) for p in sorted_filenames) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info("About to get the shuffled small, medium and large cuts") + small_cuts = self.small_cuts() + medium_cuts = self.medium_cuts() + large_cuts = self.large_cuts() + return CutSet.mux( + small_cuts, + medium_cuts, + 
large_cuts, + weights=[ + 122867, # len(small_cuts) + 1104071, # len(medium_cuts) + 11012085, # len(large_cuts) + ], + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) diff --git a/egs/librilight/SSL/zipformer/utils.py b/egs/librilight/SSL/zipformer/utils.py new file mode 120000 index 000000000..119992bdb --- /dev/null +++ b/egs/librilight/SSL/zipformer/utils.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/utils.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/wav2vec2_module.py b/egs/librilight/SSL/zipformer/wav2vec2_module.py new file mode 120000 index 000000000..81ad701e4 --- /dev/null +++ b/egs/librilight/SSL/zipformer/wav2vec2_module.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/wav2vec2_module.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/zipformer.py b/egs/librilight/SSL/zipformer/zipformer.py new file mode 120000 index 000000000..5b3da8cd5 --- /dev/null +++ b/egs/librilight/SSL/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/SSL/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/SSL/hubert/asr_datamodule.py b/egs/librispeech/SSL/hubert/asr_datamodule.py new file mode 100644 index 000000000..3746d8a3a --- /dev/null +++ b/egs/librispeech/SSL/hubert/asr_datamodule.py @@ -0,0 +1,287 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from dataset import HubertAsrDataset +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/wav"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=float, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + group.add_argument( + "--do-normalize", + type=str2bool, + default=True, + help="whether to normalize the data", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + do_normalize: bool, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = HubertAsrDataset(do_normalize=do_normalize) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader: + logging.info("About to create dev dataset") + validate = HubertAsrDataset(do_normalize=do_normalize) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet, do_normalize: bool) -> DataLoader: + logging.debug("About to create test dataset") + test = HubertAsrDataset(do_normalize=do_normalize) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + train_clean_100_cuts = self.train_clean_100_cuts() + train_clean_360_cuts = self.train_clean_360_cuts() + train_other_500_cuts = self.train_other_500_cuts() + return CutSet.mux( + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + weights=[ + 28539, # len(train_clean_100_cuts) + 104014, # len(train_clean_360_cuts) + 148688, # len(train_other_500_cuts) + ], + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/librispeech/SSL/hubert/attention_module.py b/egs/librispeech/SSL/hubert/attention_module.py new file mode 100644 index 000000000..39ef8698e --- /dev/null +++ b/egs/librispeech/SSL/hubert/attention_module.py @@ -0,0 +1,840 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import utils +from torch import Tensor, nn +from torch.nn import Parameter +from utils import FairseqDropout, quant_noise + +_xformers_available = False + + +# TODO: move this into xformers? +# TODO: uint8 input type should just output a bool +def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None): + """ + call to pytorch multihead accepts three mask types: + - ByteTensor where non-zero means to mask + - FloatTensor which is an additive mask + - BoolTensor where True means to mask + xFormers currently accepts boolean and additive maks. For boolean masks + the values have opposite meaning. For a BoolTensor True mean to keep the value. + """ + float_types = [torch.float, torch.float16] + # If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool. + additive = mask.dtype in float_types + # If to_dype is not specified, keep same dtype as mask. + to_dtype = mask.dtype if to_dtype is None else to_dtype + to_additive = to_dtype in float_types + + if additive: + if to_additive: + return mask.to(to_dtype) + mask = mask < 0 + + if to_additive: + # return additive mask + new_mask = torch.zeros_like(mask, dtype=to_dtype) + new_mask = new_mask.masked_fill_(mask, -float("inf")) + return new_mask + + # In xFormers True is value to keep rather than value to mask + mask = ~mask.to(torch.bool) + mask = mask.to(to_dtype) + return mask + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
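+
+    Note that in this copy the initialization is applied unconditionally, i.e.
+    the flags mentioned above are not inspected here: Linear and Embedding
+    weights are drawn from N(0, 0.02) (biases and padding rows are zeroed),
+    and the q/k/v projections of MultiheadAttention are re-initialized in the
+    same way.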
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + dictionary=None, + q_noise=0.0, + qn_block_size=8, + # TODO: pass in config rather than string. + # config defined in xformers.components.attention.AttentionConfig + xformers_att_config: Optional[str] = None, + xformers_blocksparse_layout: Optional[ + torch.Tensor + ] = None, # This should be part of the config + xformers_blocksparse_blocksize: Optional[ + int + ] = 16, # This should be part of the config + ): + super().__init__() + + self.use_xformers = False + if self.use_xformers and not _xformers_available: + raise ImportError("\n\n Please install xFormers.") + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert ( + not self.self_attention or self.qkv_same_dim + ), "Self-attention requires query, key and value to be of the same size" + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + self.beam_size = 1 + self.reset_parameters() + + self.onnx_trace = False + self.skip_embed_dim_check = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + 
nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def _get_reserve_head_index(self, num_heads_to_keep: int): + k_proj_heads_norm = [] + q_proj_heads_norm = [] + v_proj_heads_norm = [] + + for i in range(self.num_heads): + start_idx = i * self.head_dim + end_idx = (i + 1) * self.head_dim + k_proj_heads_norm.append( + torch.sum( + torch.abs( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist() + ) + q_proj_heads_norm.append( + torch.sum( + torch.abs( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist() + ) + v_proj_heads_norm.append( + torch.sum( + torch.abs( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist() + ) + + heads_norm = [] + for i in range(self.num_heads): + heads_norm.append( + k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i] + ) + + sorted_head_index = sorted( + range(self.num_heads), key=lambda k: heads_norm[k], reverse=True + ) + reserve_head_index = [] + for i in range(num_heads_to_keep): + start = sorted_head_index[i] * self.head_dim + end = (sorted_head_index[i] + 1) * self.head_dim + reserve_head_index.append((start, end)) + return reserve_head_index + + def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]): + new_q_weight = [] + new_q_bias = [] + new_k_weight = [] + new_k_bias = [] + new_v_weight = [] + new_v_bias = [] + new_out_proj_weight = [] + + for ele in reserve_head_index: + start_idx, end_idx = ele + new_q_weight.append( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) + + new_k_weight.append( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + + new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) + + new_v_weight.append( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) + + new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) + + new_q_weight = torch.cat(new_q_weight).detach() + new_k_weight = torch.cat(new_k_weight).detach() + new_v_weight = torch.cat(new_v_weight).detach() + new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach() + new_q_weight.requires_grad = True + new_k_weight.requires_grad = True + new_v_weight.requires_grad = True + new_out_proj_weight.requires_grad = True + + new_q_bias = torch.cat(new_q_bias).detach() + new_q_bias.requires_grad = True + + new_k_bias = torch.cat(new_k_bias).detach() + new_k_bias.requires_grad = True + + new_v_bias = torch.cat(new_v_bias).detach() + new_v_bias.requires_grad = True + + self.q_proj.weight = torch.nn.Parameter(new_q_weight) + self.q_proj.bias = torch.nn.Parameter(new_q_bias) + + self.k_proj.weight = torch.nn.Parameter(new_k_weight) + self.k_proj.bias = torch.nn.Parameter(new_k_bias) + + self.v_proj.weight = torch.nn.Parameter(new_v_weight) + self.v_proj.bias = torch.nn.Parameter(new_v_bias) + + self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight) + + self.num_heads = len(reserve_head_index) + self.embed_dim = self.head_dim * self.num_heads + 
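+        # Also update the bookkeeping attributes of the projection layers so
+        # that their reported output size matches the pruned embedding size.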
self.q_proj.out_features = self.embed_dim + self.k_proj.out_features = self.embed_dim + self.v_proj.out_features = self.embed_dim + + def _set_skip_embed_dim_check(self): + self.skip_embed_dim_check = True + + def _pad_masks( + self, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + if attn_mask is not None: + shape = attn_mask.size()[:-1] + torch.Size([1]) + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1) + if key_padding_mask is not None: + shape = key_padding_mask.size()[:-1] + torch.Size([1]) + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(shape), + ], + dim=-1, + ) + return key_padding_mask, attn_mask + + def _add_bias( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + bsz: int, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + assert self.bias_k is not None + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _append_zero_attn( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:] + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], + dim=-2, + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], + dim=-2, + ) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def forward( + self, + query: Tensor, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
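+
+        Returns:
+            attn (Tensor): the attention output after the output projection,
+                of shape `(tgt_len, batch, embed_dim)`.
+            attn_weights (Tensor, optional): the attention weights, averaged
+                over heads unless *need_head_weights* is True; None when
+                *need_weights* is False.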
+ """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + if not self.skip_embed_dim_check: + assert ( + embed_dim == self.embed_dim + ), f"query dim {embed_dim} != {self.embed_dim}" + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert value is not None + assert src_len, key_bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + # The Multihead attention implemented in pytorch forces strong dimension check + # for input embedding dimention and K,Q,V projection dimension. + # Since pruning will break the dimension check and it is not easy to modify the pytorch API, + # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check + and not self.skip_embed_dim_check + ): + assert key is not None and value is not None + + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask.bool() if key_padding_mask is not None else None, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + if self.beam_size > 1 and bsz == key.size(1): + # key is [T, bsz*beam_size, C], reduce to [T, bsz, C] + key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ + :, :, 0, : + ] + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.view( + -1, self.beam_size, key_padding_mask.size(1) + )[:, 0, :] + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + kv_bsz = bsz # need default value for scripting + if k is not None: + kv_bsz = k.size(1) + k = ( + k.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + 
.transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + kv_bsz = _prev_key.size(0) + prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + assert kv_bsz == _prev_value.size(0) + prev_value = _prev_value.view( + kv_bsz * self.num_heads, -1, self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=kv_bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view( + kv_bsz, self.num_heads, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == kv_bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn_weights = torch.einsum( + "bxhtd,bhsd->bxhts", + q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), + k.view((kv_bsz, self.num_heads) + k.size()[1:]), + ) + attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) + else: + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [ + bsz * self.num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.view( + kv_bsz, -1, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return 
attn_weights, v + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn: Optional[Tensor] = None + if self.encoder_decoder_attention and bsz != kv_bsz: + attn = torch.einsum( + "bxhts,bhsd->bxhtd", + attn_probs.view( + ( + kv_bsz, + -1, + self.num_heads, + ) + + attn_probs.size()[1:] + ), + v.view( + ( + kv_bsz, + self.num_heads, + ) + + v.size()[1:] + ), + ) + attn = attn.reshape((-1,) + attn.size()[-2:]) + else: + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [ + bsz * self.num_heads, + tgt_len, + self.head_dim, + ] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention: + if input_buffer_k.size(0) * self.beam_size == new_order.size(0): + return incremental_state + elif self.beam_size > 1: + input_buffer[k] = input_buffer_k.index_select( + 0, + new_order.reshape(-1, self.beam_size)[:, 0] + // 
self.beam_size, + ) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def set_beam_size(self, beam_size): + """Used for effiecient beamable enc-dec attention""" + self.beam_size = beam_size + + def _get_input_buffer( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ + dim : 2 * dim + ] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value diff --git a/egs/librispeech/SSL/hubert/beam_search.py b/egs/librispeech/SSL/hubert/beam_search.py new file mode 120000 index 000000000..f4d4b5732 --- /dev/null +++ b/egs/librispeech/SSL/hubert/beam_search.py @@ -0,0 +1 @@ +../../ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/SSL/hubert/dataset.py b/egs/librispeech/SSL/hubert/dataset.py new file mode 100644 index 000000000..76edfb340 --- /dev/null +++ b/egs/librispeech/SSL/hubert/dataset.py @@ -0,0 +1,367 @@ +# Copyright 2024 Xiaomi Corporation (authors: Yifan Yang) +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
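+
+# This file provides two Lhotse-style datasets whose __getitem__ receives a
+# whole CutSet (the mini-batch of cuts selected by the sampler):
+#
+#   * HubertDataset    -- for pretraining: collates raw audio together with
+#                         frame-level k-means targets read from
+#                         cut.custom["kmeans"] (a space-separated string of
+#                         integer labels).
+#   * HubertAsrDataset -- for fine-tuning and decoding: collates raw audio and
+#                         the supervision transcripts instead of k-means labels.
+#
+# A runnable example with DynamicBucketingSampler is given in the __main__
+# block at the bottom of this file.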
+ +import sys +from typing import Any, Dict, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.collation import read_audio_from_cuts +from torch.utils.data.dataloader import default_collate + + +class HubertDataset(torch.utils.data.Dataset): + """ + In this implementation, there will always be a single channel. + + Returns: + + .. code-block:: + + { + 'audio': (B x NumSamples) float tensor + } + """ + + def __init__( + self, + max_sample_size: Optional[int] = None, + sample_rate: float = 16000, + label_rate: float = 50, + random_crop: bool = True, + pad_audio: bool = False, + num_classes: list = [504], + do_normalize: bool = True, + ) -> None: + super().__init__() + self.sample_rate = sample_rate + self.label_rate = label_rate + self.random_crop = random_crop + self.pad_audio = pad_audio + self.num_classes = num_classes + self.normalize = do_normalize + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Any]: + self._validate(cuts) + audio, _ = read_audio_from_cuts(cuts) + for i, item in enumerate(audio): + audio[i] = self.postprocess(item, self.sample_rate) + audio_lens = [cut.num_samples for cut in cuts] + + if self.pad_audio: + audio_size = min(max(audio_lens), self.max_sample_size) + else: + audio_size = min(min(audio_lens), self.max_sample_size) + + audio, padding_mask, audio_starts = self.collater_audio( + audio, audio_lens, audio_size + ) + + kmeans = [cut.custom["kmeans"] for cut in cuts] + kmeans = [ + torch.tensor([int(item) for item in label.split()], dtype=torch.int64) + for label in kmeans + ] + kmeans, _ = self.collater_frm_label(kmeans, audio_size, audio_starts) + + return { + "cuts": cuts, + "audio": audio, + "padding_mask": padding_mask, + "kmeans": kmeans, + } + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav + + def _validate(self, cuts: CutSet) -> None: + validate(cuts) + assert all(cut.has_recording for cut in cuts) + + def crop_to_max_size(self, wav, target_size): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + + start, end = 0, target_size + if self.random_crop: + start = np.random.randint(0, diff + 1) + end = size - diff + start + return wav[start:end], start + + def collater_audio(self, audios, audio_lens, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + # if self.pad_audio else None + ) + audio_starts = [0 for _ in audios] + for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)): + audio = audio[:audio_len] + diff = audio_len - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size + ) + return collated_audios, padding_mask, audio_starts + + def collate_tokens( + self, + values, + pad_idx, + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + pad_to_length=None, + pad_to_multiple=1, + 
pad_to_bsz=None, + ): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + + batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz) + res = values[0].new(batch_size, size).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if move_eos_to_beginning: + if eos_idx is None: + # if no eos_idx is specified, then use the last token in src + dst[0] = src[-1] + else: + dst[0] = eos_idx + dst[1:] = src[:-1] + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) + return res + + def collater_frm_label(self, targets, audio_size, audio_starts): + label_rate = self.label_rate + pad = self.num_classes[0] - 1 + assert label_rate > 0 + s2f = label_rate / self.sample_rate + frm_starts = [int(round(s * s2f)) for s in audio_starts] + frm_size = int(round(audio_size * s2f)) + if not self.pad_audio: + rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] + frm_size = min(frm_size, *rem_size) + targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)] + + lengths = torch.LongTensor([len(t) for t in targets]) + targets = self.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths + + +class HubertAsrDataset(torch.utils.data.Dataset): + """ + In this implementation, there will always be a single channel. + + Returns: + + .. code-block:: + + { + 'audio': (B x NumSamples) float tensor + } + """ + + def __init__( + self, + max_sample_size: Optional[int] = None, + sample_rate: float = 16000, + random_crop: bool = True, + pad_audio: bool = True, + do_normalize: bool = True, + ) -> None: + super().__init__() + self.sample_rate = sample_rate + self.random_crop = random_crop + self.pad_audio = pad_audio + self.normalize = do_normalize + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Any]: + self._validate(cuts) + audio, _ = read_audio_from_cuts(cuts) + for i, item in enumerate(audio): + audio[i] = self.postprocess(item, self.sample_rate) + audio_lens = [cut.num_samples for cut in cuts] + if self.pad_audio: + audio_size = min(max(audio_lens), self.max_sample_size) + else: + audio_size = min(min(audio_lens), self.max_sample_size) + + audio, padding_mask, audio_starts = self.collater_audio( + audio, audio_lens, audio_size + ) + + return { + "cuts": cuts, + "audio": audio, + "padding_mask": padding_mask, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav + + def _validate(self, cuts: CutSet) -> None: + validate(cuts) + assert all(cut.has_recording for cut in cuts) + + def crop_to_max_size(self, wav, target_size): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + + start, end = 0, target_size + if self.random_crop: + start = 
np.random.randint(0, diff + 1) + end = size - diff + start + return wav[start:end], start + + def collater_audio(self, audios, audio_lens, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + # if self.pad_audio else None + ) + audio_starts = [0 for _ in audios] + for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)): + audio = audio[:audio_len] + diff = audio_len - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size + ) + return collated_audios, padding_mask, audio_starts + + def collate_tokens( + self, + values, + pad_idx, + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + pad_to_length=None, + pad_to_multiple=1, + pad_to_bsz=None, + ): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + + batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz) + res = values[0].new(batch_size, size).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if move_eos_to_beginning: + if eos_idx is None: + # if no eos_idx is specified, then use the last token in src + dst[0] = src[-1] + else: + dst[0] = eos_idx + dst[1:] = src[:-1] + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) + return res + + +if __name__ == "__main__": + from lhotse import load_manifest_lazy + from lhotse.dataset import DynamicBucketingSampler + from torch.utils.data import DataLoader + + dataset = HubertDataset() + cuts = load_manifest_lazy("data/fbank2/librispeech_cuts_train-clean-100.jsonl.gz") + sampler = DynamicBucketingSampler( + cuts, + max_duration=100, + shuffle=False, + ) + dl = DataLoader( + dataset, + batch_size=None, + sampler=sampler, + num_workers=2, + ) + + for batch_idx, batch in enumerate(dl): + print(batch) + break diff --git a/egs/librispeech/SSL/hubert/decode.py b/egs/librispeech/SSL/hubert/decode.py new file mode 100644 index 000000000..837061b8c --- /dev/null +++ b/egs/librispeech/SSL/hubert/decode.py @@ -0,0 +1,1045 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
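One number worth making concrete before the decoding script below: `HubertDataset.collater_frm_label` above converts sample offsets into label-frame offsets with `s2f = label_rate / sample_rate`. With the dataset defaults (50 Hz k-means labels over 16 kHz audio) that is one label per 320 samples, so crop offsets and the cropped length are simply scaled by 1/320 and rounded. A small numeric check, with made-up batch values:

# Numeric check of the audio-to-label alignment used in collater_frm_label.
# The sample_rate/label_rate values match the HubertDataset defaults; the
# audio_size and audio_starts below are invented for illustration.
sample_rate = 16000
label_rate = 50
s2f = label_rate / sample_rate  # 1 label frame per 320 samples

audio_size = 250000              # cropped batch length in samples
audio_starts = [0, 3200, 12800]  # per-utterance random-crop offsets

frm_starts = [int(round(s * s2f)) for s in audio_starts]
frm_size = int(round(audio_size * s2f))

print(frm_starts)  # [0, 10, 40] -> label offsets in frames
print(frm_size)    # 781 -> label frames kept per utterance (before the
                   #        min() against remaining label lengths when
                   #        pad_audio is False)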
+""" +Usage: +(1) greedy search +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from finetune import add_model_arguments, get_model, get_params +from hubert import add_hubert_arguments + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hubert/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + add_hubert_arguments(parser) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + + encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(batch["supervisions"]["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["cuts"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" 
in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
+            if params.has_contexts:
+                params.suffix += f"-context-score-{params.context_score}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_shallow_fusion:
+        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+    if "LODR" in params.decoding_method:
+        params.suffix += (
+            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+        )
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+

f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + dev_clean_dl = librispeech.test_dataloaders( + dev_clean_cuts, + do_normalize=params.do_normalize, + ) + dev_other_dl = librispeech.test_dataloaders( + dev_other_cuts, + do_normalize=params.do_normalize, + ) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders( + test_clean_cuts, + do_normalize=params.do_normalize, + ) + test_other_dl = librispeech.test_dataloaders( + test_other_cuts, + do_normalize=params.do_normalize, + ) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/hubert/decode_ce.py b/egs/librispeech/SSL/hubert/decode_ce.py new file mode 100644 index 000000000..a8d8bc9c2 --- /dev/null +++ b/egs/librispeech/SSL/hubert/decode_ce.py @@ -0,0 +1,1045 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
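Both decode.py above and decode_ce.py below share the same `--epoch` / `--avg` / `--iter` checkpoint-selection logic. As a quick reference, here is a sketch of which files the plain (non `--use-averaged-model`) epoch path ends up averaging; the helper name is mine, not part of the recipe:

# Sketch only: mirrors the `start = params.epoch - params.avg + 1` loop in main().
def checkpoints_to_average(exp_dir: str, epoch: int, avg: int) -> list:
    """E.g. epoch=30, avg=15 -> exp_dir/epoch-16.pt ... epoch-30.pt (15 files)."""
    start = epoch - avg + 1
    return [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1) if i >= 1]


print(checkpoints_to_average("hubert/exp", epoch=30, avg=15))

With `--use-averaged-model` (the default), the scripts instead combine only the averaged-model states stored in `epoch-{epoch-avg}.pt` and `epoch-{epoch}.pt`, which is why the `--use-averaged-model` help text describes the effective range as from `epoch-avg` (excluded) to `epoch`.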
+""" +Usage: +(1) greedy search +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./hubert/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./hubert/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from finetune_ce import add_model_arguments, get_model, get_params +from hubert_ce import add_hubert_arguments + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hubert/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + add_hubert_arguments(parser) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + + encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(batch["supervisions"]["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["cuts"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" 
in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
+            if params.has_contexts:
+                params.suffix += f"-context-score-{params.context_score}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_shallow_fusion:
+        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+    if "LODR" in params.decoding_method:
+        params.suffix += (
+            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+        )
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+
f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + dev_clean_dl = librispeech.test_dataloaders( + dev_clean_cuts, + do_normalize=params.do_normalize, + ) + dev_other_dl = librispeech.test_dataloaders( + dev_other_cuts, + do_normalize=params.do_normalize, + ) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders( + test_clean_cuts, + do_normalize=params.do_normalize, + ) + test_other_dl = librispeech.test_dataloaders( + test_other_cuts, + do_normalize=params.do_normalize, + ) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/hubert/decoder.py b/egs/librispeech/SSL/hubert/decoder.py new file mode 120000 index 000000000..a2138e5da --- /dev/null +++ b/egs/librispeech/SSL/hubert/decoder.py @@ -0,0 +1 @@ +../../ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py new file mode 100644 index 000000000..201847aed --- /dev/null +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -0,0 +1,1254 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
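For clarity, the `--use-averaged-model` branch of the decoding script above selects its averaging range purely by filename: with `--epoch E --avg N` it averages from `epoch-(E-N).pt` (excluded) up to `epoch-E.pt` (included). Below is a minimal sketch of that convention; the helper name and the example experiment directory are illustrative only, not part of the recipe.

def averaged_model_endpoints(exp_dir: str, epoch: int, avg: int):
    """Return the (excluded, included) checkpoint filenames bounding the average."""
    assert avg > 0, avg
    start = epoch - avg
    assert start >= 1, start
    filename_start = f"{exp_dir}/epoch-{start}.pt"  # excluded from the average
    filename_end = f"{exp_dir}/epoch-{epoch}.pt"    # included in the average
    return filename_start, filename_end

print(averaged_model_endpoints("exp", epoch=30, avg=5))
# ('exp/epoch-25.pt', 'exp/epoch-30.pt')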
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For HuBERT model finetuning: +./hubert/finetune.py \ + --world-size 8 \ + --num-epochs 200 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir hubert/exp \ + --full-libri 0 \ + --max-duration 200 + +It supports finetuning with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from hubert import HubertModel, add_hubert_arguments +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * params.accum_grad + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=222, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hubert/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--pretrained-dir", + type=str, + help="""The pretrained model dir. + It specifies the directory where the pretrained checkpoint is saved.""", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.001, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--sanity-check", + type=str2bool, + default=False, + help="Check if any of the batches in epoch 1 would cause OOM.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=100000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accum-grad", + type=int, + default=1, + help="""update gradient when batch_idx_train % accum_grad == 0. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_hubert_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + + contains number of updates happen to the model so far across + epochs. + + - sub_batch_idx_train: It contains number of batch trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "sub_batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for pruned RNN-T loss + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if hasattr(params, "pretrained_dir"): + logging.info(f"Loading {params.pretrained_dir}") + pretrained = torch.load(params.pretrained_dir) + encoder = HubertModel(params) + encoder.load_state_dict(pretrained["model"]) + else: + encoder = HubertModel(params) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_embed_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_embed_dim, + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `dataset.HubertAsrDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, num_frames = model( + x=audio, + padding_mask=padding_mask, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = num_frames.sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. 
+ tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for sub_batch_idx, batch in enumerate(train_dl): + params.sub_batch_idx_train += 1 + batch_idx = sub_batch_idx // params.accum_grad + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss / params.accum_grad).backward() + + if sub_batch_idx % params.accum_grad == params.accum_grad - 1: + params.batch_idx_train += 1 + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if batch_idx % params.accum_grad != params.accum_grad - 1: + optimizer.zero_grad() + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = ( + librispeech.train_all_shuf_cuts() + if params.full_libri + else librispeech.train_clean_100_cuts() + ) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + do_normalize=params.do_normalize, + sampler_state_dict=sampler_state_dict, + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + + valid_dl = librispeech.valid_dataloaders( + valid_cuts, + do_normalize=params.do_normalize, + ) + + if params.sanity_check and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `dataset.HubertAsrDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + audio = batch["audio"] + logging.info(f"audio shape: {audio.shape}") + + y = sp.encode(batch["supervisions"]["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py new file mode 100644 index 000000000..e69a5a8cd --- /dev/null +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -0,0 +1,1254 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
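The `--average-period` help text above describes `model_avg` as a running mean of all parameters since the start of training; the recipe delegates the actual update to `update_averaged_model` from `icefall.checkpoint`. A minimal per-tensor sketch of that update rule follows, with an illustrative helper name that is not part of the recipe.

import torch

def update_running_average(
    model_avg: torch.Tensor,
    model_cur: torch.Tensor,
    batch_idx_train: int,
    average_period: int,
) -> torch.Tensor:
    # model_avg = model * (period / t) + model_avg * ((t - period) / t),
    # which keeps model_avg equal to the mean of the parameter values
    # sampled every `average_period` batches up to batch `t`.
    weight_cur = average_period / batch_idx_train
    return model_cur * weight_cur + model_avg * (1.0 - weight_cur)

# e.g. at batch 400 with average_period=200, the current model contributes half:
print(update_running_average(torch.zeros(3), torch.ones(3), 400, 200))
# tensor([0.5000, 0.5000, 0.5000])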
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For HuBERT model finetuning: +./hubert/finetune.py \ + --world-size 8 \ + --num-epochs 200 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir hubert/exp \ + --full-libri 0 \ + --max-duration 200 + +It supports finetuning with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from hubert_ce import HubertModel, add_hubert_arguments +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * params.accum_grad + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=222, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hubert/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--pretrained-dir", + type=str, + help="""The pretrained model dir. + It specifies the directory where the pretrained checkpoint is saved.""", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.001, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--sanity-check", + type=str2bool, + default=False, + help="Check if any of the batches in epoch 1 would cause OOM.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=100000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accum-grad", + type=int, + default=1, + help="""update gradient when batch_idx_train % accum_grad == 0. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_hubert_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + + contains number of updates happen to the model so far across + epochs. + + - sub_batch_idx_train: It contains number of batch trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "sub_batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for pruned RNN-T loss + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if hasattr(params, "pretrained_dir"): + logging.info(f"Loading {params.pretrained_dir}") + pretrained = torch.load(params.pretrained_dir) + encoder = HubertModel(params) + encoder.load_state_dict(pretrained["model"]) + else: + encoder = HubertModel(params) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_embed_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_embed_dim, + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `dataset.HubertAsrDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, num_frames = model( + x=audio, + padding_mask=padding_mask, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = num_frames.sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. 
+ tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for sub_batch_idx, batch in enumerate(train_dl): + params.sub_batch_idx_train += 1 + batch_idx = sub_batch_idx // params.accum_grad + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss / params.accum_grad).backward() + + if sub_batch_idx % params.accum_grad == params.accum_grad - 1: + params.batch_idx_train += 1 + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
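+            # Concretely: while the scale is below 8.0 it is doubled every 100
+            # batches, and every 400 batches while it is below 32.0.  If it drops
+            # under 0.01 we log a warning and dump a one-off "bad model" checkpoint;
+            # under 1e-5 we abort, since the fp16 loss has almost certainly diverged.
+            # scaler._scale is a private GradScaler attribute; scaler.get_scale()
+            # is the public accessor for the same value.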
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if batch_idx % params.accum_grad != params.accum_grad - 1: + optimizer.zero_grad() + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = ( + librispeech.train_all_shuf_cuts() + if params.full_libri + else librispeech.train_clean_100_cuts() + ) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + do_normalize=params.do_normalize, + sampler_state_dict=sampler_state_dict, + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + + valid_dl = librispeech.valid_dataloaders( + valid_cuts, + do_normalize=params.do_normalize, + ) + + if params.sanity_check and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `dataset.HubertAsrDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + audio = batch["audio"] + logging.info(f"audio shape: {audio.shape}") + + y = sp.encode(batch["supervisions"]["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/hubert/hubert.py b/egs/librispeech/SSL/hubert/hubert.py new file mode 100644 index 000000000..f800044f4 --- /dev/null +++ b/egs/librispeech/SSL/hubert/hubert.py @@ -0,0 +1,984 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
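+"""HuBERT pre-training model (masked prediction over discrete cluster targets).
+
+The model below closely follows the fairseq HuBERT implementation: a
+convolutional waveform feature extractor, a transformer encoder, span masking
+along the time axis (see compute_mask_indices) and an NCE-style loss that
+scores the projected encoder output against learned label embeddings.
+
+Rough usage sketch, assuming you already have a batch of raw 16 kHz waveforms
+and frame-level k-means cluster ids (variable names below are illustrative
+only; the configuration fields come from add_hubert_arguments):
+
+    parser = argparse.ArgumentParser()
+    add_hubert_arguments(parser)
+    cfg = parser.parse_args([])  # defaults only, for illustration
+    model = HubertModel(cfg)
+    loss, sample_size, logging_output = model(
+        source=waveforms,           # (B, num_samples) float tensor
+        target_list=[cluster_ids],  # list of (B, T_label) long tensors
+        padding_mask=padding_mask,  # (B, num_samples) bool, True at padding
+    )
+"""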
+ +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils import GradMultiply, LayerNorm +from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + if num_mask_ver == 1: + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + else: + raise ValueError() + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, 
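+            # NOTE: np.random.default_rng() returns a numpy Generator, which
+            # exposes .integers() but not .randint(); this rng.randint call (and
+            # the one inside arrange() below) is only reached for non-default
+            # settings (mask_type="uniform" or no_overlap=True).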
size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError(f"this should never happens") + else: + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + raise ValueError( + ( + f"the entire sequence is masked. " + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) + ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + + +def add_hubert_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--label-rate", + type=float, + default=50, + ) + + parser.add_argument( + "--sample-rate", + type=float, + default=16000, + ) + + parser.add_argument( + "--extractor-mode", + type=str, + default="default", + help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. 
default has a single group + norm with d groups in the first conv block, whereas layer_norm + has layer norms in every block (meant to use with normalize=True)""", + ) + parser.add_argument( + "--encoder-layers", + type=int, + default=12, + help="num encoder layers in the transformer", + ) + + parser.add_argument( + "--encoder-embed-dim", + type=int, + default=768, + help="encoder embedding dimension", + ) + + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + default=3072, + help="encoder embedding dimension for FFN", + ) + + parser.add_argument( + "--encoder-attention-heads", + type=int, + default=12, + help="num encoder attention heads", + ) + + parser.add_argument( + "--activation-fn", + type=str, + choices=[ + "relu", + "gelu", + "gelu_fast", + "gelu_accurate", + "tanh", + "linear", + ], + default="gelu", + help="activation function to use", + ) + + parser.add_argument( + "--layer-type", + type=str, + choices=["transformer", "conformer", "trf_adp"], + default="transformer", + help="layer type in encoder", + ) + + # dropouts + parser.add_argument( + "--dropout", + type=float, + default=0.1, + help="dropout probability for the transformer", + ) + + parser.add_argument( + "--attention-dropout", + type=float, + default=0.1, + help="dropout probability for attention weights", + ) + + parser.add_argument( + "--activation-dropout", + type=float, + default=0.0, + help="dropout probability after activation in FFN", + ) + + parser.add_argument( + "--encoder-layerdrop", + type=float, + default=0.0, + help="probability of dropping a tarnsformer layer", + ) + + parser.add_argument( + "--dropout-input", + type=float, + default=0.0, + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + default=0.0, + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--final-dim", + type=int, + default=0, + help="project final representations and targets to this many dimensions. 
set to encoder_embed_dim is <= 0", + ) + + parser.add_argument( + "--untie-final-proj", + type=bool, + default=False, + help="use separate projection for each target", + ) + + parser.add_argument( + "--layer-norm-first", + type=bool, + default=False, + help="apply layernorm first in the transformer", + ) + + parser.add_argument( + "--conv-feature-layers", + type=str, + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--conv-bias", + type=bool, + default=False, + help="include bias in conv encoder", + ) + + parser.add_argument( + "--logit-temp", + type=float, + default=0.1, + help="temperature to divide logits by", + ) + + parser.add_argument( + "--target-glu", + type=bool, + default=False, + help="adds projection + glu to targets", + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + default=1.0, + help="multiply feature extractor var grads by this", + ) + + # masking + parser.add_argument("--mask-length", type=int, default=10, help="mask_length") + + parser.add_argument( + "--mask-prob", + type=float, + default=0.65, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length", + ) + + parser.add_argument( + "--mask-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--no-mask-overlap", + type=bool, + default=False, + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-min-space", + type=int, + default=1, + help="min space between spans (if no overlap is enabled)", + ) + + # channel masking + parser.add_argument( + "--mask-channel-length", + type=int, + default=10, + help="length of the mask for features (channels)", + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + default=0.0, + help="probability of replacing a feature with 0", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length for channel masking", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + type=bool, + default=False, + help="whether to allow channel masks to overlap", + ) + + parser.add_argument( + "--mask-channel-min-space", + type=int, + default=1, + help="min space between spans (if no overlap is enabled)", + ) + + # positional embeddings + parser.add_argument( + "--conv-pos", + type=int, + default=128, + help="number of filters for convolutional positional embeddings", + ) + + parser.add_argument( + "--conv-pos-groups", + type=int, + default=16, + help="number of groups for convolutional positional embedding", + ) + + parser.add_argument( + "--conv-pos-batch-norm", + type=bool, + default=False, + help="use batch norm instead of weight norm in conv_pos (for bf16 models)", + ) + + parser.add_argument( + "--latent-temp", + type=float, + nargs="*", + default=[2, 0.5, 0.999995], + help="legacy (to be removed)", + ) + + # loss computation + parser.add_argument( + "--skip-masked", + 
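+        # NOTE: with type=bool argparse converts the raw command-line string via
+        # bool(), so any non-empty value (even "False") is parsed as True; the
+        # same caveat applies to the other type=bool options in this parser.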
type=bool, + default=False, + help="skip computing losses over masked frames", + ) + + parser.add_argument( + "--skip-nomask", + type=bool, + default=False, + help="skip computing losses over unmasked frames", + ) + + parser.add_argument( + "--checkpoint-activations", + type=bool, + default=False, + help="recompute activations and save memory for extra compute", + ) + + parser.add_argument( + "--pred-masked-weight", + type=float, + default=1, + help="weight for masked part in ssl loss", + ) + + parser.add_argument( + "--pred-nomask-weight", + type=float, + default=0, + help="weight for masked part in ssl loss", + ) + + parser.add_argument( + "--loss-weights", + type=float, + nargs="*", + default=[10], + help="weight for masked part in ssl loss", + ) + + # FP16 optimization + parser.add_argument( + "--required-seq-len-multiple", + type=int, + default=2, + help="pad the input to encoder such that the sequence length is divisible by multiple", + ) + + parser.add_argument( + "--attn-type", type=str, default="", help="if espnet use ESPNET MHA" + ) + + parser.add_argument( + "--pos-enc-type", + type=str, + default="abs", + help="Positional encoding type to use in conformer", + ) + + parser.add_argument( + "--num-classes", + type=int, + nargs="*", + default=[504], + help="""num class, a little larger than the number of cluster, + the largest is for padding, + and the value should be the multiple of 4, for faster computation""", + ) + + +class HubertModel(nn.Module): + def __init__( + self, + cfg, + ) -> None: + super().__init__() + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.untie_final_proj = cfg.untie_final_proj + if self.untie_final_proj: + self.final_proj = nn.Linear( + cfg.encoder_embed_dim, final_dim * len(cfg.num_classes) + ) + else: + self.final_proj = 
nn.Linear(cfg.encoder_embed_dim, final_dim) + + # modules below are not needed during fine-tuning + self.num_classes = cfg.num_classes + self.label_embs_concat = nn.Parameter( + torch.FloatTensor(sum(self.num_classes), final_dim) + ) + self.pred_masked_weight = cfg.pred_masked_weight + self.pred_nomask_weight = cfg.pred_nomask_weight + self.loss_weights = cfg.loss_weights + nn.init.uniform_(self.label_embs_concat) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + def apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb.to(x.dtype) + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits + + def forward_features(self, source: torch.Tensor) -> torch.Tensor: + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + return features + + def forward_targets( + self, + features: torch.Tensor, + target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, target_list + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + return padding_mask + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None, + ): + """output layer is 1-based""" + features = 
self.forward_features(source) + if target_list is not None: + features, target_list = self.forward_targets(features, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask, target_list) + else: + x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1, + ) + + if features_only: + return {"x": x, "padding_mask": padding_mask, "features": features} + + def compute_pred(proj_x, target, label_embs): + # compute logits for the i-th label set + y = torch.index_select(label_embs, 0, target.long()) + negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1) + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + # proj_x: (S, D) + # y: (S, D) + # negs: (Neg, S, D) + return self.compute_nce(proj_x, y, negs) + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + proj_x_m = self.final_proj(x[masked_indices]) + if self.untie_final_proj: + proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1) + else: + proj_x_m_list = [proj_x_m for _ in range(len(target_list))] + logit_m_list = [ + compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) + for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list)) + ] + else: + logit_m_list = [None for _ in target_list] + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = self.final_proj(x[nomask_indices]) + if self.untie_final_proj: + proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1) + else: + proj_x_u_list = [proj_x_u for _ in range(len(target_list))] + + logit_u_list = [ + compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) + for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list)) + ] + else: + logit_u_list = [None for _ in target_list] + + # result = { + # "logit_m_list": logit_m_list, + # "logit_u_list": logit_u_list, + # "padding_mask": padding_mask, + # "features_pen": features_pen, + # } + return self.compute_loss(logit_m_list, logit_u_list, features_pen) + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [x.float() for x in logits_list if x is not None] + return logits_list + + def get_targets(self, net_output, is_masked=True): + 
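+        # compute_nce places the positive example at index 0 of every logit row,
+        # so the cross-entropy target for each row is always class 0.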
logits_list = self.get_logits(net_output, is_masked) + targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list] + return targets_list + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None + + def compute_loss(self, logit_m_list, logit_u_list, features_pen): + loss = 0.0 + sample_size = 0 + logging_output = {} + reduce = True + reduction = "sum" if reduce else "none" + + loss_m_list = [] + logp_m_list = [x.float() for x in logit_m_list if x is not None] + targ_m_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_m_list] + assert self.pred_masked_weight == 0 or len(logp_m_list) > 0 + for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)): + loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction) + loss_m_list.append(loss_m) + logging_output[f"loss_m_{i}"] = loss_m.detach().item() + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += targ_m_list[0].numel() + + loss_u_list = [] + logp_u_list = [x.float() for x in logit_u_list if x is not None] + targ_u_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_u_list] + assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0 + for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)): + loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_{i}"] = loss_u.detach().item() + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += targ_u_list[0].numel() + + if self.loss_weights is not None: + extra_losses = [] + names = [] + extra_losses.append(features_pen) + names.append("features_pen") + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + logging_output = { + "loss": loss.item() if reduce else loss, + **logging_output, + } + + # for lk in self.log_keys: + # if lk in net_output: + # logging_output[lk] = float((net_output[lk])) + + def compute_correct(logits): + if logits.numel() == 0: + return 0, 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + return corr, count + + with torch.no_grad(): + for i, logp_m in enumerate(logp_m_list): + corr_m, count_m = compute_correct(logp_m) + logging_output[f"correct_m_{i}"] = corr_m + logging_output[f"count_m_{i}"] = count_m + + for i, logp_u in enumerate(logp_u_list): + corr_u, count_u = compute_correct(logp_u) + logging_output[f"correct_u_{i}"] = corr_u + logging_output[f"count_u_{i}"] = count_u + + return loss, sample_size, logging_output diff --git a/egs/librispeech/SSL/hubert/hubert_ce.py b/egs/librispeech/SSL/hubert/hubert_ce.py new file mode 100644 index 000000000..ccdd63efd --- /dev/null +++ 
b/egs/librispeech/SSL/hubert/hubert_ce.py @@ -0,0 +1,940 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils import GradMultiply, LayerNorm +from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + if num_mask_ver == 1: + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + else: + raise ValueError() + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError(f"this should never happens") + else: + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in 
range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + raise ValueError( + ( + f"the entire sequence is masked. " + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) + ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + + +def add_hubert_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--label-rate", + type=float, + default=50, + ) + + parser.add_argument( + "--sample-rate", + type=float, + default=16000, + ) + + parser.add_argument( + "--extractor-mode", + type=str, + default="default", + help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. default has a single group + norm with d groups in the first conv block, whereas layer_norm + has layer norms in every block (meant to use with normalize=True)""", + ) + parser.add_argument( + "--encoder-layers", + type=int, + default=12, + help="num encoder layers in the transformer", + ) + + parser.add_argument( + "--encoder-embed-dim", + type=int, + default=768, + help="encoder embedding dimension", + ) + + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + default=3072, + help="encoder embedding dimension for FFN", + ) + + parser.add_argument( + "--encoder-attention-heads", + type=int, + default=12, + help="num encoder attention heads", + ) + + parser.add_argument( + "--activation-fn", + type=str, + choices=[ + "relu", + "gelu", + "gelu_fast", + "gelu_accurate", + "tanh", + "linear", + ], + default="gelu", + help="activation function to use", + ) + + parser.add_argument( + "--layer-type", + type=str, + choices=["transformer", "conformer", "trf_adp"], + default="transformer", + help="layer type in encoder", + ) + + # dropouts + parser.add_argument( + "--dropout", + type=float, + default=0.1, + help="dropout probability for the transformer", + ) + + parser.add_argument( + "--attention-dropout", + type=float, + default=0.1, + help="dropout probability for attention weights", + ) + + parser.add_argument( + "--activation-dropout", + type=float, + default=0.0, + help="dropout probability after activation in FFN", + ) + + parser.add_argument( + "--encoder-layerdrop", + type=float, + default=0.0, + help="probability of dropping a tarnsformer layer", + ) + + parser.add_argument( + "--dropout-input", + type=float, + default=0.0, + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + default=0.0, + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--final-dim", + type=int, + default=0, + help="project final representations and targets to this many dimensions. 
set to encoder_embed_dim is <= 0", + ) + + parser.add_argument( + "--untie-final-proj", + type=bool, + default=False, + help="use separate projection for each target", + ) + + parser.add_argument( + "--layer-norm-first", + type=bool, + default=False, + help="apply layernorm first in the transformer", + ) + + parser.add_argument( + "--conv-feature-layers", + type=str, + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--conv-bias", + type=bool, + default=False, + help="include bias in conv encoder", + ) + + parser.add_argument( + "--logit-temp", + type=float, + default=0.1, + help="temperature to divide logits by", + ) + + parser.add_argument( + "--target-glu", + type=bool, + default=False, + help="adds projection + glu to targets", + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + default=1.0, + help="multiply feature extractor var grads by this", + ) + + # masking + parser.add_argument("--mask-length", type=int, default=10, help="mask_length") + + parser.add_argument( + "--mask-prob", + type=float, + default=0.65, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length", + ) + + parser.add_argument( + "--mask-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--no-mask-overlap", + type=bool, + default=False, + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-min-space", + type=int, + default=1, + help="min space between spans (if no overlap is enabled)", + ) + + # channel masking + parser.add_argument( + "--mask-channel-length", + type=int, + default=10, + help="length of the mask for features (channels)", + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + default=0.0, + help="probability of replacing a feature with 0", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + default="static", + help="how to choose mask length for channel masking", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + default=0, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + type=bool, + default=False, + help="whether to allow channel masks to overlap", + ) + + parser.add_argument( + "--mask-channel-min-space", + type=int, + default=1, + help="min space between spans (if no overlap is enabled)", + ) + + # positional embeddings + parser.add_argument( + "--conv-pos", + type=int, + default=128, + help="number of filters for convolutional positional embeddings", + ) + + parser.add_argument( + "--conv-pos-groups", + type=int, + default=16, + help="number of groups for convolutional positional embedding", + ) + + parser.add_argument( + "--conv-pos-batch-norm", + type=bool, + default=False, + help="use batch norm instead of weight norm in conv_pos (for bf16 models)", + ) + + parser.add_argument( + "--latent-temp", + type=float, + nargs="*", + default=[2, 0.5, 0.999995], + help="legacy (to be removed)", + ) + + # loss computation + parser.add_argument( + "--skip-masked", + 
type=bool, + default=False, + help="skip computing losses over masked frames", + ) + + parser.add_argument( + "--skip-nomask", + type=bool, + default=False, + help="skip computing losses over unmasked frames", + ) + + parser.add_argument( + "--checkpoint-activations", + type=bool, + default=False, + help="recompute activations and save memory for extra compute", + ) + + parser.add_argument( + "--pred-masked-weight", + type=float, + default=1, + help="weight for masked part in ssl loss", + ) + + parser.add_argument( + "--pred-nomask-weight", + type=float, + default=0, + help="weight for masked part in ssl loss", + ) + + parser.add_argument( + "--loss-weights", + type=float, + nargs="*", + default=[10], + help="weight for masked part in ssl loss", + ) + + # FP16 optimization + parser.add_argument( + "--required-seq-len-multiple", + type=int, + default=2, + help="pad the input to encoder such that the sequence length is divisible by multiple", + ) + + parser.add_argument( + "--attn-type", type=str, default="", help="if espnet use ESPNET MHA" + ) + + parser.add_argument( + "--pos-enc-type", + type=str, + default="abs", + help="Positional encoding type to use in conformer", + ) + + parser.add_argument( + "--num-classes", + type=int, + nargs="*", + default=[504], + help="""num class, a little larger than the number of cluster, + the largest is for padding, + and the value should be the multiple of 4, for faster computation""", + ) + + +class HubertModel(nn.Module): + def __init__( + self, + cfg, + ) -> None: + super().__init__() + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.untie_final_proj = cfg.untie_final_proj + self.final_proj = nn.Linear(cfg.encoder_embed_dim, sum(cfg.num_classes)) + + # modules below are not needed during fine-tuning + self.num_classes = cfg.num_classes + self.pred_masked_weight = cfg.pred_masked_weight + self.pred_nomask_weight = cfg.pred_nomask_weight + self.loss_weights = cfg.loss_weights + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly 
old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + def apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb.to(x.dtype) + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_features(self, source: torch.Tensor) -> torch.Tensor: + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + return features + + def forward_targets( + self, + features: torch.Tensor, + target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, target_list + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + return padding_mask + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None, + ): + """output layer is 1-based""" + features = self.forward_features(source) + if target_list is not None: + features, target_list = self.forward_targets(features, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask, target_list) + else: + x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + 
padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1, + ) + + if features_only: + return {"x": x, "padding_mask": padding_mask, "features": features} + + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + proj_x_m = self.final_proj(x[masked_indices]) + proj_x_m /= self.logit_temp + logit_m_list = [proj_x_m for _ in range(len(target_list))] + else: + logit_m_list = [None for _ in target_list] + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = self.final_proj(x[nomask_indices]) + proj_x_u /= self.logit_temp + logit_u_list = [proj_x_u for _ in range(len(target_list))] + else: + logit_u_list = [None for _ in target_list] + + # result = { + # "logit_m_list": logit_m_list, + # "logit_u_list": logit_u_list, + # "padding_mask": padding_mask, + # "features_pen": features_pen, + # } + targ_m_list = target_list[0][masked_indices] + targ_m_list = targ_m_list.long() + targ_m_list = [targ_m_list for _ in range(len(target_list))] + + targ_u_list = target_list[0][nomask_indices] + targ_u_list = targ_u_list.long() + targ_u_list = [targ_u_list for _ in range(len(target_list))] + return self.compute_loss( + logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen + ) + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [x.float() for x in logits_list if x is not None] + return logits_list + + def get_targets(self, net_output, is_masked=True): + logits_list = self.get_logits(net_output, is_masked) + targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list] + return targets_list + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.final_proj = None + + def compute_loss( + self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen + ): + loss = 0.0 + sample_size = 0 + logging_output = {} + reduce = True + reduction = "sum" if reduce else "none" + + loss_m_list = [] + logp_m_list = [x.float() for x in logit_m_list if x is not None] + logp_m_list = torch.cat(logp_m_list) + targ_m_list = torch.cat(targ_m_list) + + loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction) + loss_m_list.append(loss_m) + logging_output[f"loss_m_0"] = loss_m.detach().item() + + assert self.pred_masked_weight == 0 or len(logp_m_list) > 0 + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += len(targ_m_list) + + loss_u_list = [] + logp_u_list = [x.float() for x in logit_u_list if x is not None] + logp_u_list = torch.cat(logp_u_list) + targ_u_list = torch.cat(targ_u_list) + + loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_0"] = 
loss_u.detach().item() + + assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0 + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += len(targ_u_list) + + if self.loss_weights is not None: + extra_losses = [] + names = [] + extra_losses.append(features_pen) + names.append("features_pen") + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + logging_output = { + "loss": loss.item() if reduce else loss, + **logging_output, + } + + # for lk in self.log_keys: + # if lk in net_output: + # logging_output[lk] = float((net_output[lk])) + + def compute_correct(logits, target): + if logits.numel() == 0: + return 0, 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == target + min = logits.argmin(-1) == target + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + return corr, count + + with torch.no_grad(): + corr_m, count_m = compute_correct(logp_m_list, targ_m_list) + logging_output[f"correct_m_0"] = corr_m + logging_output[f"count_m_0"] = count_m + + corr_u, count_u = compute_correct(logp_u_list, targ_u_list) + logging_output[f"correct_u_0"] = corr_u + logging_output[f"count_u_0"] = count_u + + return loss, sample_size, logging_output diff --git a/egs/librispeech/SSL/hubert/joiner.py b/egs/librispeech/SSL/hubert/joiner.py new file mode 120000 index 000000000..aa3362cda --- /dev/null +++ b/egs/librispeech/SSL/hubert/joiner.py @@ -0,0 +1 @@ +../../ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py new file mode 100644 index 000000000..46a968b69 --- /dev/null +++ b/egs/librispeech/SSL/hubert/model.py @@ -0,0 +1,344 @@ +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from scaling import ScaledLinear + +from icefall.utils import add_sos + + +class AsrModel(nn.Module): + def __init__( + self, + encoder, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + encoder_dim: int = 768, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + ): + """A joint CTC & Transducer ASR model. 
+ + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder: + It is the transcription network in the paper. Its accepts + inputs: `x` of (N, T, encoder_dim). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + self.encoder = encoder + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.25 + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.25 + ) + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward_encoder( + self, + x: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 2-D tensor of shape (N, T). + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + if padding_mask is None: + padding_mask = torch.zeros_like(x, dtype=torch.bool) + + encoder_out, padding_mask = self.encoder.extract_features( + source=x, + padding_mask=padding_mask, + mask=self.encoder.training, + ) + encoder_out_lens = torch.sum(~padding_mask, dim=1) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + + return encoder_out, encoder_out_lens + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. 
+ """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets, + input_lengths=encoder_out_lens, + target_lengths=target_lengths, + reduction="sum", + ) + return ctc_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + y: k2.RaggedTensor, + padding_mask: Optional[torch.Tensor] = None, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 2-D tensor of shape (N, T). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 2, x.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == y.dim0, (x.shape, y.dim0) + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + # Compute CTC loss + targets = y.values + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss, encoder_out_lens diff --git a/egs/librispeech/SSL/hubert/optim.py b/egs/librispeech/SSL/hubert/optim.py new file mode 120000 index 000000000..56b827b8a --- /dev/null +++ b/egs/librispeech/SSL/hubert/optim.py @@ -0,0 +1 @@ +../../ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py new file mode 100644 index 000000000..d9bda8857 --- /dev/null +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -0,0 +1,1082 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For hubert model pretraining: +./hubert/pretrain.py \ + --world-size 8 \ + --num-epochs 400 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir hubert/exp \ + --full-libri 1 \ + --max-duration 87.5 \ + --accum-grad 4 +""" + + +import argparse +import copy +import logging +import sys +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from hubert import HubertModel, add_hubert_arguments +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from ssl_datamodule import LibriSpeechDataModule +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.functional import pad +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * params.accum_grad + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=400, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hubert/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=5000, + help="Eden warmup steps", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0, + help="Eden warmup start learning rate", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=80, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--sanity-check", + type=str2bool, + default=False, + help="Check if any of the batches in epoch 1 would cause OOM.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=100000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accum-grad", + type=int, + default=4, + help="""update gradient when batch_idx_train % accum_grad == 0. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--max-keep-size", + type=int, + default=sys.maxsize, + help="exclude sample longer than this.", + ) + + parser.add_argument( + "--min-keep-size", + type=float, + default=32000, + help="exclude sample shorter than this.", + ) + + parser.add_argument( + "--max-sample-size", + type=float, + default=250000, + help="max sample size to crop to for batching.", + ) + + add_hubert_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used for writing statistics to tensorboard. It + contains the number of updates that have happened to the model so far across + epochs. + + - sub_batch_idx_train: It contains the number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "sub_batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_model(params: AttributeDict) -> nn.Module: + model = HubertModel(params) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info.
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `dataset.HubertDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + kmeans = batch["kmeans"].to(device) + + with torch.set_grad_enabled(is_training): + loss, num_masked_tokens, logging_output = model( + source=audio, target_list=[kmeans], padding_mask=padding_mask + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = num_masked_tokens + for item in logging_output: + info[item] = logging_output[item] + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for sub_batch_idx, batch in enumerate(train_dl): + params.sub_batch_idx_train += 1 + batch_idx = sub_batch_idx // params.accum_grad + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + batch_size = batch["kmeans"].shape[0] + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss / params.accum_grad).backward() + + if sub_batch_idx % params.accum_grad == params.accum_grad - 1: + params.batch_idx_train += 1 + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if batch_idx % params.accum_grad != params.accum_grad - 1: + optimizer.zero_grad() + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + params.warmup_batches, + params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechDataModule(args) + + train_cuts = ( + librispeech.train_all_shuf_cuts() + if params.full_libri + else librispeech.train_clean_100_cuts() + ) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if ( + c.duration < params.min_keep_size / params.sample_rate + or c.duration > params.max_keep_size / params.sample_rate + ): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + max_sample_size=params.max_sample_size, + sample_rate=params.sample_rate, + label_rate=params.label_rate, + random_crop=params.random_crop, + pad_audio=False, + num_classes=params.num_classes, + do_normalize=params.do_normalize, + sampler_state_dict=sampler_state_dict, + ) + + valid_cuts = librispeech.dev_clean_cuts() + # valid_cuts += librispeech.dev_other_cuts() + valid_cuts = valid_cuts.filter(remove_short_and_long_utt) + + valid_dl = librispeech.valid_dataloaders( + valid_cuts, + max_sample_size=params.max_sample_size, + sample_rate=params.sample_rate, + label_rate=params.label_rate, + random_crop=params.random_crop, + pad_audio=False, + num_classes=params.num_classes, + do_normalize=params.do_normalize, + ) + + if params.sanity_check and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `dataset.HubertDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + audio = batch["audio"] + logging.info(f"audio shape: {audio.shape}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py new file mode 100644 index 000000000..24c0d4d3a --- /dev/null +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -0,0 +1,1082 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For hubert model pretraining: +./hubert/pretrain.py \ + --world-size 8 \ + --num-epochs 400 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir hubert/exp \ + --full-libri 1 \ + --max-duration 87.5 \ + --accum-grad 4 +""" + + +import argparse +import copy +import logging +import sys +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from hubert_ce import HubertModel, add_hubert_arguments +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from ssl_datamodule import LibriSpeechDataModule +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.functional import pad +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * params.accum_grad + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=400, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hubert/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=5000, + help="Eden warmup steps", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0, + help="Eden warmup start learning rate", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=80, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--sanity-check", + type=str2bool, + default=False, + help="Check if any of the batches in epoch 1 would cause OOM.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=100000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accum-grad", + type=int, + default=4, + help="""update gradient when batch_idx_train % accum_grad == 0. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--max-keep-size", + type=int, + default=sys.maxsize, + help="exclude sample longer than this.", + ) + + parser.add_argument( + "--min-keep-size", + type=float, + default=32000, + help="exclude sample shorter than this.", + ) + + parser.add_argument( + "--max-sample-size", + type=float, + default=250000, + help="max sample size to crop to for batching.", + ) + + add_hubert_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used for writing statistics to tensorboard. It + contains the number of updates that have happened to the model so far across + epochs. + + - sub_batch_idx_train: It contains the number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "sub_batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_model(params: AttributeDict) -> nn.Module: + model = HubertModel(params) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info.
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `dataset.HubertDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + kmeans = batch["kmeans"].to(device) + + with torch.set_grad_enabled(is_training): + loss, num_masked_tokens, logging_output = model( + source=audio, target_list=[kmeans], padding_mask=padding_mask + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = num_masked_tokens + for item in logging_output: + info[item] = logging_output[item] + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for sub_batch_idx, batch in enumerate(train_dl): + params.sub_batch_idx_train += 1 + batch_idx = sub_batch_idx // params.accum_grad + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + batch_size = batch["kmeans"].shape[0] + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss / params.accum_grad).backward() + + if sub_batch_idx % params.accum_grad == params.accum_grad - 1: + params.batch_idx_train += 1 + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if batch_idx % params.accum_grad != params.accum_grad - 1: + optimizer.zero_grad() + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + params.warmup_batches, + params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechDataModule(args) + + train_cuts = ( + librispeech.train_all_shuf_cuts() + if params.full_libri + else librispeech.train_clean_100_cuts() + ) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if ( + c.duration < params.min_keep_size / params.sample_rate + or c.duration > params.max_keep_size / params.sample_rate + ): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + max_sample_size=params.max_sample_size, + sample_rate=params.sample_rate, + label_rate=params.label_rate, + random_crop=params.random_crop, + pad_audio=False, + num_classes=params.num_classes, + do_normalize=params.do_normalize, + sampler_state_dict=sampler_state_dict, + ) + + valid_cuts = librispeech.dev_clean_cuts() + # valid_cuts += librispeech.dev_other_cuts() + valid_cuts = valid_cuts.filter(remove_short_and_long_utt) + + valid_dl = librispeech.valid_dataloaders( + valid_cuts, + max_sample_size=params.max_sample_size, + sample_rate=params.sample_rate, + label_rate=params.label_rate, + random_crop=params.random_crop, + pad_audio=False, + num_classes=params.num_classes, + do_normalize=params.do_normalize, + ) + + if params.sanity_check and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `dataset.HubertDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + audio = batch["audio"] + logging.info(f"audio shape: {audio.shape}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/hubert/scaling.py b/egs/librispeech/SSL/hubert/scaling.py new file mode 120000 index 000000000..e30bd99de --- /dev/null +++ b/egs/librispeech/SSL/hubert/scaling.py @@ -0,0 +1 @@ +../../ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/SSL/hubert/ssl_datamodule.py b/egs/librispeech/SSL/hubert/ssl_datamodule.py new file mode 100644 index 000000000..ac1a0997d --- /dev/null +++ b/egs/librispeech/SSL/hubert/ssl_datamodule.py @@ -0,0 +1,341 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from dataset import HubertDataset +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechDataModule: + """ + DataModule for SSL experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in SSL + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + + This class should be derived for specific corpora used in SSL tasks. 
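+
+    A rough sketch of the intended call pattern (the argument values below
+    are illustrative only):
+
+        parser = argparse.ArgumentParser()
+        LibriSpeechDataModule.add_arguments(parser)
+        args = parser.parse_args(["--max-duration", "200"])
+
+        librispeech = LibriSpeechDataModule(args)
+        train_cuts = librispeech.train_clean_100_cuts()
+        train_dl = librispeech.train_dataloaders(train_cuts)
+        valid_dl = librispeech.valid_dataloaders(librispeech.dev_clean_cuts())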
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="SSL data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/kmeans"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=float, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + group.add_argument( + "--do-normalize", + type=str2bool, + default=True, + help="whether to normalize the data", + ) + group.add_argument( + "--random-crop", + type=str2bool, + default=True, + help="always crop from the beginning if false", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + max_sample_size: Optional[int] = None, + sample_rate: float = 16000, + label_rate: float = 50, + random_crop: bool = True, + pad_audio: bool = False, + num_classes: list = [504], + do_normalize: bool = True, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = HubertDataset( + max_sample_size=max_sample_size, + sample_rate=sample_rate, + label_rate=label_rate, + random_crop=random_crop, + pad_audio=pad_audio, + num_classes=num_classes, + do_normalize=do_normalize, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + max_sample_size: Optional[int] = None, + sample_rate: float = 16000, + label_rate: float = 50, + random_crop: bool = True, + pad_audio: bool = False, + num_classes: list = [504], + do_normalize: bool = True, + ) -> DataLoader: + logging.info("About to create dev dataset") + validate = HubertDataset( + max_sample_size=max_sample_size, + sample_rate=sample_rate, + label_rate=label_rate, + random_crop=random_crop, + pad_audio=pad_audio, + num_classes=num_classes, + do_normalize=do_normalize, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders( + self, + cuts: CutSet, + sample_rate: float = 16000, + label_rate: float = 50, + random_crop: bool = True, + pad_audio: bool = False, + num_classes: list = [504], + do_normalize: bool = True, + ) -> DataLoader: + logging.debug("About to create test dataset") + test = HubertDataset( + sample_rate=sample_rate, + label_rate=label_rate, + random_crop=random_crop, + pad_audio=pad_audio, + num_classes=num_classes, + do_normalize=do_normalize, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + train_clean_100_cuts = self.train_clean_100_cuts() + train_clean_360_cuts = self.train_clean_360_cuts() + train_other_500_cuts = self.train_other_500_cuts() + return CutSet.mux( + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + weights=[ + 28539, # len(train_clean_100_cuts) + 104014, # len(train_clean_360_cuts) + 148688, # len(train_other_500_cuts) + ], + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def 
test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/librispeech/SSL/hubert/utils.py b/egs/librispeech/SSL/hubert/utils.py new file mode 100644 index 000000000..de980ba62 --- /dev/null +++ b/egs/librispeech/SSL/hubert/utils.py @@ -0,0 +1,338 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def relu_squared(x: torch.Tensor): + return F.relu(x).pow(2) + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == "xla" + + +def index_put(tensor, indices, value): + if is_xla_tensor(tensor): + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + +def pad_to_multiple(x, multiple, dim=-1, value=0): + # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41 + if x is None: + return None, 0 + tsz = x.size(dim) + m = tsz / multiple + remainder = math.ceil(m) * multiple - tsz + if m.is_integer(): + return x, 0 + pad_offset = (0,) * (-1 - dim) * 2 + + return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str) -> Callable: + """Returns the activation function corresponding to `activation`""" + if activation == "relu": + return F.relu + elif activation == "relu_squared": + return relu_squared + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif 
activation == "linear": + return lambda x: x + elif activation == "swish": + return torch.nn.SiLU + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class SamePad2d(nn.Module): + def __init__(self, kernel_size): + super().__init__() + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + assert len(x.size()) == 4 + if self.remove > 0: + x = x[:, :, : -self.remove, : -self.remove] + return x + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None, tranpose_dim=-2): + super().__init__() + self.deconstruct_idx = deconstruct_idx + self.tranpose_dim = tranpose_dim + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(self.tranpose_dim, -1) + + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + has_fused_layernorm = False + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + export = True + if not export and torch.cuda.is_available() and has_fused_layernorm: + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping 
blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class FairseqDropout(nn.Module): + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x, inplace: bool = False): + if self.p > 0 and (self.training or self.apply_during_inference): + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + + def make_generation_fast_( + self, + name: str, + retain_dropout: bool = False, + retain_dropout_modules: Optional[List[str]] = None, + **kwargs + ): + if retain_dropout: + if retain_dropout_modules is not None and self.module_name is None: + pass + elif ( + retain_dropout_modules is None # if None, apply to all modules + or self.module_name in retain_dropout_modules + ): + self.apply_during_inference = True + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/egs/librispeech/SSL/hubert/wav2vec2_module.py b/egs/librispeech/SSL/hubert/wav2vec2_module.py new file mode 100644 index 000000000..4c2e1ce98 --- /dev/null +++ b/egs/librispeech/SSL/hubert/wav2vec2_module.py @@ -0,0 +1,593 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from attention_module import MultiheadAttention, init_bert_params +from utils import ( + Fp32GroupNorm, + Fp32LayerNorm, + LayerNorm, + SamePad, + TransposeLast, + get_activation_fn, + index_put, + pad_to_multiple, +) + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x + + +def make_conv_pos(e, k, g, is_batch_norm=False): + pos_conv = nn.Conv1d( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (k * e)) + nn.init.normal_(pos_conv.weight, mean=0, std=std) + nn.init.constant_(pos_conv.bias, 0) + + if not is_batch_norm: + pos_conv = nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2) + pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU()) + else: + batch_norm = nn.BatchNorm1d(e) + pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), 
nn.GELU()) + + return pos_conv + + +class TransformerEncoder(nn.Module): + def build_encoder_layer(self, args, **kwargs): + if args.layer_type == "transformer": + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + elif args.layer_type == "trf_adp": + use_adp = False + if args.adp_trf_idx == "all": + use_adp = True + else: + adp_trf_idx = list( + range(*[int(g) for g in args.adp_trf_idx.split(":")]) + ) + if kwargs.get("layer_idx", None) in adp_trf_idx: + use_adp = True + if use_adp: + layer = TransformerSentenceEncoderWithAdapterLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + adapter_num=args.adp_num, + adapter_dim=args.adp_dim, + adapter_act_fn=args.adp_act_fn, + ) + else: + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + + # layer = fsdp_wrap(layer) + # if args.checkpoint_activations: + # layer = checkpoint_wrapper(layer) + return layer + + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.required_seq_len_multiple = args.required_seq_len_multiple + + pos_conv_depth = getattr(args, "pos_conv_depth", 1) + if pos_conv_depth > 1: + num_layers = args.pos_conv_depth + k = max(3, args.conv_pos // num_layers) + + def make_conv_block(e, k, g, l): + return nn.Sequential( + *[ + nn.Sequential( + nn.Conv1d( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ), + SamePad(k), + TransposeLast(), + LayerNorm(e, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ) + for _ in range(l) + ] + ) + + self.pos_conv = make_conv_block( + self.embedding_dim, k, args.conv_pos_groups, num_layers + ) + + else: + self.pos_conv = make_conv_pos( + self.embedding_dim, + args.conv_pos, + args.conv_pos_groups, + is_batch_norm=args.conv_pos_batch_norm + if hasattr(args, "conv_pos_batch_norm") + else False, + ) + + self.layers = nn.ModuleList( + [ + self.build_encoder_layer(args, layer_idx=ii) + for ii in range(args.encoder_layers) + ] + ) + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, layer=None, corpus_key=None): + x, layer_results = self.extract_features( + x, padding_mask, layer, corpus_key=corpus_key + ) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features( + self, + x, + padding_mask=None, + tgt_layer=None, + min_layer=0, + corpus_key=None, + ): + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + x_conv = self.pos_conv(x.transpose(1, 2)) + 
x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + # pad to the sequence length dimension + x, pad_length = pad_to_multiple( + x, self.required_seq_len_multiple, dim=-2, value=0 + ) + if pad_length > 0 and padding_mask is None: + padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) + padding_mask[:, -pad_length:] = True + else: + padding_mask, _ = pad_to_multiple( + padding_mask, self.required_seq_len_multiple, dim=-1, value=True + ) + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + r = None + + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() if self.layerdrop > 0 else 1 + if not self.training or (dropout_probability > self.layerdrop): + layer_check = layer + # if isinstance(layer, FullyShardedDataParallel): + # layer_check = layer.unwrapped_module + if (corpus_key is None) or ( + not isinstance( + layer_check, + (TransformerSentenceEncoderWithAdapterLayer,), + ) + ): + x, (z, lr) = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + ) + else: + x, (z, lr) = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + corpus_key=corpus_key, + ) + if i >= min_layer: + layer_results.append((x, z, lr)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + # undo paddding + if pad_length > 0: + x = x[:, :-pad_length] + + def undo_pad(a, b, c): + return ( + a[:-pad_length], + b[:-pad_length] if b is not None else b, + c[:-pad_length], + ) + + layer_results = [undo_pad(*u) for u in layer_results] + + return x, layer_results + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
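+
+    As used by `TransformerEncoder.extract_features` above, the input is laid
+    out as (seq_len, batch, embedding_dim) and `forward` returns
+    `(x, (attn, layer_result))`, where `layer_result` is the output of the
+    feed-forward block taken before the final dropout and residual connection.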
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + ) -> None: + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask, + need_weights=False, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, (attn, layer_result) + + +class AdapterFast(nn.Module): + def __init__(self, adapter_num, input_dim, hidden_dim, act_fn): + """ + Implements adapter modules directly with 3D tensor weight as parameters + and without using ModuleList orto speed up training throughput. 
+ """ + super().__init__() + + self.adapter_num = adapter_num + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim)) + self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim)) + self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim)) + self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim)) + + self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim)) + self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim)) + self.act_fn = nn.Identity() + if act_fn == "relu": + self.act_fn = nn.ReLU() + elif act_fn == "gelu": + self.act_fn = nn.GELU() + elif act_fn == "selu": + self.act_fn = nn.SELU() + else: + raise ValueError(f"unsupported {act_fn}") + + self.input_dim = input_dim + self.reset_parameters() + + def reset_parameters(self): + for ii in range(self.adapter_num): + nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii]) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.b_a[ii], -bound, bound) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii]) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.b_b[ii], -bound, bound) + + nn.init.ones_(self.ln_W) + nn.init.zeros_(self.ln_b) + + def forward(self, x, adapter_id): + ii = adapter_id + h = x + h = F.layer_norm(h, (self.input_dim,), self.ln_W[ii], self.ln_b[ii]) + h = F.linear(h, self.W_a[ii], self.b_a[ii]) + h = self.act_fn(h) + h = F.linear(h, self.W_b[ii], self.b_b[ii]) + outputs = h + return outputs + + def extra_repr(self): + return "adapter={}, input_dim={}, hidden_dim={}".format( + self.adapter_num, self.input_dim, self.hidden_dim + ) + + +class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer): + """ + Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained + models. An adapter module is added along with vanilla Transformer module. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + adapter_num=201, + adapter_dim=64, + adapter_act_fn="relu", + ) -> None: + super().__init__( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + layer_norm_first=layer_norm_first, + ) + + self.adapter_num = adapter_num + self.adapter_dim = adapter_dim + self.adapter_layer = AdapterFast( + adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn + ) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + corpus_key=None, + ): + x, (attn, layer_result) = super().forward( + x=x, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + att_args=att_args, + ) + assert corpus_key is not None + assert len(set(corpus_key)) == 1, f"corpus_key items are not same {corpus_key}" + y = self.adapter_layer(x, corpus_key[0]) + x = x + y + return x, (attn, layer_result) diff --git a/egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py b/egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py new file mode 100644 index 000000000..aa2f45f75 --- /dev/null +++ b/egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py @@ -0,0 +1,52 @@ +import os + +import jsonlines +from tqdm import tqdm + +os.system( + "cp /userhome/user/yfy62/librispeech_data/data4ssl/manifests/librispeech_*_dev-clean* ." +) +os.system( + "cp /userhome/user/yfy62/librispeech_data/data4ssl/manifests/librispeech_*_train* ." 
+) +os.system("chmod -R 644 *.jsonl.gz") +os.system("gunzip *.gz") + +dataset_parts = ( + "dev-clean", + "train-clean-100", + "train-clean-360", + "train-other-500", +) + +kmeans_dir = "/userhome/user/yangguanrou/data/k500" +idx_dir = "/userhome/user/yangguanrou/data/shu" + +kmeans = [] +idxs = [] +for part in ["train", "valid"]: + with open(kmeans_dir + "/" + part + ".km", "r") as f: + kmeans += f.read().splitlines() + + with open(idx_dir + "/" + part + ".tsv", "r") as f: + lines = f.read().splitlines() + idxs += [ + line.split("\t", -1)[0].split("/", -1)[-1].replace(".flac", "") + for line in lines + if ".flac" in line + ] + +idx2kmeans = {} +for idx, km in zip(idxs, kmeans): + idx2kmeans[idx] = km + +for part in dataset_parts: + with jsonlines.open(f"librispeech_supervisions_{part}.jsonl") as reader: + with jsonlines.open( + f"librispeech_supervisions_{part}_new.jsonl", mode="w" + ) as writer: + for obj in tqdm(reader): + obj["custom"] = {"kmeans": idx2kmeans[obj["id"]]} + writer.write(obj) + +os.system('for file in *_new.jsonl; do mv "$file" "${file%_new.jsonl}.jsonl"; done') diff --git a/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py new file mode 100644 index 000000000..4212cd9c6 --- /dev/null +++ b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py @@ -0,0 +1,18 @@ +# simple script to convert a fairseq checkpoint into pytorch parameter state dict +from argparse import ArgumentParser +from collections import OrderedDict + +import torch + +parser = ArgumentParser() +parser.add_argument("--src") +parser.add_argument("--tgt") + +args = parser.parse_args() +src = args.src +tgt = args.tgt + +old_checkpoint = torch.load(src) +new_checkpoint = OrderedDict() +new_checkpoint["model"] = old_checkpoint["model"] +torch.save(new_checkpoint, tgt) diff --git a/egs/librispeech/SSL/local/prepare_char.py b/egs/librispeech/SSL/local/prepare_char.py new file mode 100644 index 000000000..8cc0502c2 --- /dev/null +++ b/egs/librispeech/SSL/local/prepare_char.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/text, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +import re +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). 
+
+    Args:
+      lexicon:
+        The input lexicon. See also :func:`read_lexicon`
+      token2id:
+        A dict mapping tokens to IDs.
+      word2id:
+        A dict mapping words to IDs.
+      need_self_loops:
+        If True, add self-loop to states with non-epsilon output symbols
+        on at least one arc out of the state. The input label for this
+        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+    Returns:
+      Return an instance of `k2.Fsa` representing the given lexicon.
+    """
+    loop_state = 0  # words enter and leave from here
+    next_state = 1  # the next un-allocated state, will be incremented as we go
+
+    arcs = []
+
+    # The blank symbol <blk> is defined in local/train_bpe_model.py
+    assert token2id["<blk>"] == 0
+    assert word2id["<eps>"] == 0
+
+    eps = 0
+
+    for word, pieces in lexicon:
+        assert len(pieces) > 0, f"{word} has no pronunciations"
+        cur_state = loop_state
+
+        word = word2id[word]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+
+        for i in range(len(pieces) - 1):
+            w = word if i == 0 else eps
+            arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+            cur_state = next_state
+            next_state += 1
+
+        # now for the last piece of this word
+        i = len(pieces) - 1
+        w = word if i == 0 else eps
+        arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+    if need_self_loops:
+        disambig_token = token2id["#0"]
+        disambig_word = word2id["#0"]
+        arcs = add_self_loops(
+            arcs,
+            disambig_token=disambig_token,
+            disambig_word=disambig_word,
+        )
+
+    final_state = next_state
+    arcs.append([loop_state, final_state, -1, -1, 0])
+    arcs.append([final_state])
+
+    arcs = sorted(arcs, key=lambda arc: arc[0])
+    arcs = [[str(i) for i in arc] for arc in arcs]
+    arcs = [" ".join(arc) for arc in arcs]
+    arcs = "\n".join(arcs)
+
+    fsa = k2.Fsa.from_str(arcs, acceptor=False)
+    return fsa
+
+
+def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
+    """Check if all the given tokens are in token symbol table.
+
+    Args:
+      token_sym_table:
+        Token symbol table that contains all the valid tokens.
+      tokens:
+        A list of tokens.
+    Returns:
+      Return True if there is any token not in the token_sym_table,
+      otherwise False.
+    """
+    for tok in tokens:
+        if tok not in token_sym_table:
+            return True
+    return False
+
+
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+    """Generate a lexicon from a word list and token_sym_table.
+
+    Args:
+      token_sym_table:
+        Token symbol table that mapping token to token ids.
+      words:
+        A list of strings representing words.
+    Returns:
+      Return a dict whose keys are words and values are the corresponding
+      tokens.
+    """
+    lexicon = []
+    for word in words:
+        chars = list(word.strip(" \t"))
+        if contain_oov(token_sym_table, chars):
+            continue
+        lexicon.append((word, chars))
+
+    # The OOV word is <UNK>
+    lexicon.append(("<UNK>", ["<unk>"]))
+    return lexicon
+
+
+def generate_tokens(text_file: str) -> Dict[str, int]:
+    """Generate tokens from the given text file.
+
+    Args:
+      text_file:
+        A file that contains text lines to generate tokens.
+    Returns:
+      Return a dict whose keys are tokens and values are token ids ranged
+      from 0 to len(keys) - 1.
+ """ + tokens: Dict[str, int] = dict() + tokens[""] = 0 + tokens[""] = 1 + tokens[""] = 2 + whitespace = re.compile(r"([ \t\r\n]+)") + with open(text_file, "r", encoding="utf-8") as f: + for line in f: + line = re.sub(whitespace, "", line) + chars = list(line) + for char in chars: + if char not in tokens: + tokens[char] = len(tokens) + return tokens + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + text_file = lang_dir / "text" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", "", "#0", "", ""] + for w in excluded: + if w in words: + words.remove(w) + + token_sym_table = generate_tokens(text_file) + + lexicon = generate_lexicon(token_sym_table, words) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/local/prepare_lang.py b/egs/librispeech/SSL/local/prepare_lang.py new file mode 100644 index 000000000..c8cf9b881 --- /dev/null +++ b/egs/librispeech/SSL/local/prepare_lang.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. 
+""" +import argparse +import math +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch + +from icefall.lexicon import read_lexicon, write_lexicon + +Lexicon = List[Tuple[str, List[str]]] + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. + Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. + count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). + issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map(symbols: List[str]) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. 
+    Returns:
+      A dict containing the mapping between symbols and IDs.
+    """
+    return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+    arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+    """Adds self-loops to states of an FST to propagate disambiguation symbols
+    through it. They are added on each state with non-epsilon output symbols
+    on at least one arc out of the state.
+
+    See also fstaddselfloops.pl from Kaldi. One difference is that
+    Kaldi uses OpenFst style FSTs and it has multiple final states.
+    This function uses k2 style FSTs and it does not need to add self-loops
+    to the final state.
+
+    The input label of a self-loop is `disambig_token`, while the output
+    label is `disambig_word`.
+
+    Args:
+      arcs:
+        A list-of-list. The sublist contains
+        `[src_state, dest_state, label, aux_label, score]`
+      disambig_token:
+        It is the token ID of the symbol `#0`.
+      disambig_word:
+        It is the word ID of the symbol `#0`.
+
+    Return:
+      Return new `arcs` containing self-loops.
+    """
+    states_needs_self_loops = set()
+    for arc in arcs:
+        src, dst, ilabel, olabel, score = arc
+        if olabel != 0:
+            states_needs_self_loops.add(src)
+
+    ans = []
+    for s in states_needs_self_loops:
+        ans.append([s, s, disambig_token, disambig_word, 0])
+
+    return arcs + ans
+
+
+def lexicon_to_fst(
+    lexicon: Lexicon,
+    token2id: Dict[str, int],
+    word2id: Dict[str, int],
+    sil_token: str = "SIL",
+    sil_prob: float = 0.5,
+    need_self_loops: bool = False,
+) -> k2.Fsa:
+    """Convert a lexicon to an FST (in k2 format) with optional silence at
+    the beginning and end of each word.
+
+    Args:
+      lexicon:
+        The input lexicon. See also :func:`read_lexicon`
+      token2id:
+        A dict mapping tokens to IDs.
+      word2id:
+        A dict mapping words to IDs.
+      sil_token:
+        The silence token.
+      sil_prob:
+        The probability for adding a silence at the beginning and end
+        of the word.
+      need_self_loops:
+        If True, add self-loop to states with non-epsilon output symbols
+        on at least one arc out of the state. The input label for this
+        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+    Returns:
+      Return an instance of `k2.Fsa` representing the given lexicon.
+    """
+    assert sil_prob > 0.0 and sil_prob < 1.0
+    # CAUTION: we use score, i.e, negative cost.
+    sil_score = math.log(sil_prob)
+    no_sil_score = math.log(1.0 - sil_prob)
+
+    start_state = 0
+    loop_state = 1  # words enter and leave from here
+    sil_state = 2  # words terminate here when followed by silence; this state
+    # has a silence transition to loop_state.
+    next_state = 3  # the next un-allocated state, will be incremented as we go.
+    arcs = []
+
+    assert token2id["<eps>"] == 0
+    assert word2id["<eps>"] == 0
+
+    eps = 0
+
+    sil_token = token2id[sil_token]
+
+    arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+    arcs.append([start_state, sil_state, eps, eps, sil_score])
+    arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+    for word, tokens in lexicon:
+        assert len(tokens) > 0, f"{word} has no pronunciations"
+        cur_state = loop_state
+
+        word = word2id[word]
+        tokens = [token2id[i] for i in tokens]
+
+        for i in range(len(tokens) - 1):
+            w = word if i == 0 else eps
+            arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+            cur_state = next_state
+            next_state += 1
+
+        # now for the last token of this word
+        # It has two out-going arcs, one to the loop state,
+        # the other one to the sil_state.
+        i = len(tokens) - 1
+        w = word if i == 0 else eps
+        arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+        arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+    if need_self_loops:
+        disambig_token = token2id["#0"]
+        disambig_word = word2id["#0"]
+        arcs = add_self_loops(
+            arcs,
+            disambig_token=disambig_token,
+            disambig_word=disambig_word,
+        )
+
+    final_state = next_state
+    arcs.append([loop_state, final_state, -1, -1, 0])
+    arcs.append([final_state])
+
+    arcs = sorted(arcs, key=lambda arc: arc[0])
+    arcs = [[str(i) for i in arc] for arc in arcs]
+    arcs = [" ".join(arc) for arc in arcs]
+    arcs = "\n".join(arcs)
+
+    fsa = k2.Fsa.from_str(arcs, acceptor=False)
+    return fsa
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    return parser.parse_args()
+
+
+def main():
+    out_dir = Path(get_args().lang_dir)
+    lexicon_filename = out_dir / "lexicon.txt"
+    sil_token = "SIL"
+    sil_prob = 0.5
+
+    lexicon = read_lexicon(lexicon_filename)
+    tokens = get_tokens(lexicon)
+    words = get_words(lexicon)
+
+    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+    for i in range(max_disambig + 1):
+        disambig = f"#{i}"
+        assert disambig not in tokens
+        tokens.append(f"#{i}")
+
+    assert "<eps>" not in tokens
+    tokens = ["<eps>"] + tokens
+
+    assert "<eps>" not in words
+    assert "#0" not in words
+    assert "<s>" not in words
+    assert "</s>" not in words
+
+    words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
+
+    token2id = generate_id_map(tokens)
+    word2id = generate_id_map(words)
+
+    write_mapping(out_dir / "tokens.txt", token2id)
+    write_mapping(out_dir / "words.txt", word2id)
+    write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+    L = lexicon_to_fst(
+        lexicon,
+        token2id=token2id,
+        word2id=word2id,
+        sil_token=sil_token,
+        sil_prob=sil_prob,
+    )
+
+    L_disambig = lexicon_to_fst(
+        lexicon_disambig,
+        token2id=token2id,
+        word2id=word2id,
+        sil_token=sil_token,
+        sil_prob=sil_prob,
+        need_self_loops=True,
+    )
+    torch.save(L.as_dict(), out_dir / "L.pt")
+    torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+
+    if False:
+        # Just for debugging, will remove it
+        L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+        L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+        L_disambig.labels_sym = L.labels_sym
+        L_disambig.aux_labels_sym = L.aux_labels_sym
+        L.draw(out_dir / "L.png", title="L")
+        L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/SSL/local/process_librispeech4finetune.py b/egs/librispeech/SSL/local/process_librispeech4finetune.py
new file mode 100644
index 000000000..09f4b8a3e
--- /dev/null
+++ b/egs/librispeech/SSL/local/process_librispeech4finetune.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
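# (Illustration added by the editor, not part of the patch.) A minimal sketch
# of what add_disambig_symbols() from prepare_lang.py above produces for a toy
# lexicon; it assumes prepare_lang.py is importable from the current directory.
# Token sequences that are repeated, or are prefixes of another entry, receive
# trailing #1, #2, ... so that no pronunciation is a prefix of another.
from prepare_lang import add_disambig_symbols

toy_lexicon = [
    ("hello", ["h", "e", "l", "l", "o"]),
    ("hell", ["h", "e", "l", "l"]),        # prefix of "hello" -> gets #1
    ("hullo", ["h", "e", "l", "l", "o"]),  # same tokens as "hello" -> gets #2
]
lexicon_disambig, max_disambig = add_disambig_symbols(toy_lexicon)
for word, tokens in lexicon_disambig:
    print(word, " ".join(tokens))
# hello h e l l o #1
# hell h e l l #1
# hullo h e l l o #2
print(max_disambig)  # 2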
+ +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import CutSet +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + return parser.parse_args() + + +def process_wav_librispeech( + dataset: Optional[str] = None, +): + src_dir = Path("data/manifests") + output_dir = Path("data/wav") + + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = "librispeech" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + process_wav_librispeech( + dataset=args.dataset, + ) diff --git a/egs/librispeech/SSL/local/process_librispeech4pretrain.py b/egs/librispeech/SSL/local/process_librispeech4pretrain.py new file mode 100644 index 000000000..c375a2df3 --- /dev/null +++ b/egs/librispeech/SSL/local/process_librispeech4pretrain.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import CutSet +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. 
+# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + return parser.parse_args() + + +def process_kmeans_librispeech( + dataset: Optional[str] = None, +): + src_dir = Path(".") + output_dir = Path(".") + + if dataset is None: + dataset_parts = ( + "dev-clean", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = "librispeech" + suffix = "jsonl" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}_raw.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + process_kmeans_librispeech( + dataset=args.dataset, + ) diff --git a/egs/librispeech/SSL/local/process_raw_cuts.py b/egs/librispeech/SSL/local/process_raw_cuts.py new file mode 100644 index 000000000..9d2ee5945 --- /dev/null +++ b/egs/librispeech/SSL/local/process_raw_cuts.py @@ -0,0 +1,23 @@ +import os + +import jsonlines +from tqdm import tqdm + +dataset_parts = ( + "dev-clean", + "train-clean-100", + "train-clean-360", + "train-other-500", +) + +for part in dataset_parts: + with jsonlines.open(f"librispeech_cuts_{part}_raw.jsonl") as reader: + with jsonlines.open(f"librispeech_cuts_{part}.jsonl", mode="w") as writer: + for obj in tqdm(reader): + obj["custom"] = {"kmeans": obj["supervisions"][0]["custom"]["kmeans"]} + del obj["supervisions"][0]["custom"] + + writer.write(obj) + +os.system("rm *_raw.jsonl") +os.system("gzip *.jsonl") diff --git a/egs/librispeech/SSL/shared b/egs/librispeech/SSL/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/librispeech/SSL/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/asr_datamodule.py b/egs/librispeech/SSL/zipformer/asr_datamodule.py new file mode 120000 index 000000000..21a701163 --- /dev/null +++ b/egs/librispeech/SSL/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../hubert/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/beam_search.py b/egs/librispeech/SSL/zipformer/beam_search.py new file mode 120000 index 000000000..f4d4b5732 --- /dev/null +++ b/egs/librispeech/SSL/zipformer/beam_search.py @@ -0,0 +1 @@ +../../ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/dataset.py b/egs/librispeech/SSL/zipformer/dataset.py new file mode 120000 index 000000000..cb5aedde1 --- /dev/null +++ 
b/egs/librispeech/SSL/zipformer/dataset.py @@ -0,0 +1 @@ +../hubert/dataset.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/decode.py b/egs/librispeech/SSL/zipformer/decode.py new file mode 100644 index 000000000..1562c28b8 --- /dev/null +++ b/egs/librispeech/SSL/zipformer/decode.py @@ -0,0 +1,1043 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from finetune import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + 
average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. 
+ model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + + encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(batch["supervisions"]["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif 
params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. 
Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["cuts"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+        "modified_beam_search_LODR",
+        "modified_beam_search_lm_shallow_fusion",
+        "modified_beam_search_lm_rescore",
+        "modified_beam_search_lm_rescore_LODR",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
+            if params.has_contexts:
+                params.suffix += f"-context-score-{params.context_score}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_shallow_fusion:
+        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+    if "LODR" in params.decoding_method:
+        params.suffix += (
+            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+        )
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model
= get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. 
You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + dev_clean_dl = librispeech.test_dataloaders( + dev_clean_cuts, + do_normalize=params.do_normalize, + ) + dev_other_dl = librispeech.test_dataloaders( + dev_other_cuts, + do_normalize=params.do_normalize, + ) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders( + test_clean_cuts, + do_normalize=params.do_normalize, + ) + test_other_dl = librispeech.test_dataloaders( + test_other_cuts, + do_normalize=params.do_normalize, + ) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/zipformer/decoder.py b/egs/librispeech/SSL/zipformer/decoder.py new file mode 120000 index 000000000..a2138e5da --- /dev/null +++ b/egs/librispeech/SSL/zipformer/decoder.py @@ -0,0 +1 @@ +../../ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/encoder_interface.py b/egs/librispeech/SSL/zipformer/encoder_interface.py new file mode 120000 index 000000000..0afd669f2 --- /dev/null +++ b/egs/librispeech/SSL/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py new file mode 100644 index 000000000..bbb445320 --- /dev/null +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -0,0 +1,1551 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
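# (Illustration added by the editor, not part of the patch.) A small sketch of
# the checkpoint range used by decode.py above when --use-averaged-model=True:
# the model is averaged over the epoch range (epoch - avg, epoch], i.e.
# epoch-{epoch-avg}.pt is the excluded endpoint and epoch-{epoch}.pt the
# included one. The exp dir below is only an example (the default of --exp-dir).
epoch, avg = 28, 15
exp_dir = "zipformer/exp"
start = epoch - avg                              # 13, must be >= 1
filename_start = f"{exp_dir}/epoch-{start}.pt"   # excluded: epoch-13.pt
filename_end = f"{exp_dir}/epoch-{epoch}.pt"     # included: epoch-28.pt
print(filename_start, filename_end)
# zipformer/exp/epoch-13.pt zipformer/exp/epoch-28.pt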
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For HuBERT model finetuning: +./hubert/finetune.py \ + --world-size 8 \ + --num-epochs 200 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir hubert/exp \ + --full-libri 0 \ + --max-duration 1000 + +It supports finetuning with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from hubert_ce import HubertModel +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * params.accum_grad + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + # hubert parameters + parser.add_argument( + "--label-rate", + type=float, + default=50, + ) + + parser.add_argument( + "--sample-rate", + type=float, + default=16000, + ) + + parser.add_argument( + "--extractor-mode", + type=str, + default="default", + help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. 
default has a single group
+    norm with d groups in the first conv block, whereas layer_norm
+    has layer norms in every block (meant to use with normalize=True)""",
+    )
+
+    parser.add_argument(
+        "--conv-feature-layers",
+        type=str,
+        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+        help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
+    )
+
+    parser.add_argument(
+        "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+    )
+
+    parser.add_argument(
+        "--feature-grad-mult",
+        type=float,
+        default=1.0,
+        help="multiply feature extractor var grads by this",
+    )
+
+    # masking
+    parser.add_argument("--mask-length", type=int, default=10, help="mask length")
+
+    parser.add_argument(
+        "--mask-prob",
+        type=float,
+        default=0.65,
+        help="probability of replacing a token with mask",
+    )
+
+    parser.add_argument(
+        "--mask-selection",
+        type=str,
+        choices=["static", "uniform", "normal", "poisson"],
+        default="static",
+        help="how to choose mask length",
+    )
+
+    parser.add_argument(
+        "--mask-other",
+        type=float,
+        default=0,
+        help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
+    )
+
+    parser.add_argument(
+        "--no-mask-overlap",
+        type=bool,
+        default=False,
+        help="whether to allow masks to overlap",
+    )
+
+    parser.add_argument(
+        "--mask-min-space",
+        type=int,
+        default=1,
+        help="min space between spans (if no overlap is enabled)",
+    )
+
+    # channel masking
+    parser.add_argument(
+        "--mask-channel-length",
+        type=int,
+        default=10,
+        help="length of the mask for features (channels)",
+    )
+
+    parser.add_argument(
+        "--mask-channel-prob",
+        type=float,
+        default=0.0,
+        help="probability of replacing a feature with 0",
+    )
+
+    parser.add_argument(
+        "--mask-channel-selection",
+        type=str,
+        choices=["static", "uniform", "normal", "poisson"],
+        default="static",
+        help="how to choose mask length for channel masking",
+    )
+
+    parser.add_argument(
+        "--mask-channel-other",
+        type=float,
+        default=0,
+        help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
+    )
+
+    parser.add_argument(
+        "--no-mask-channel-overlap",
+        type=bool,
+        default=False,
+        help="whether to allow channel masks to overlap",
+    )
+
+    parser.add_argument(
+        "--mask-channel-min-space",
+        type=int,
+        default=1,
+        help="min space between spans (if no overlap is enabled)",
+    )
+
+    # loss computation
+    parser.add_argument(
+        "--skip-masked",
+        type=bool,
+        default=False,
+        help="skip computing losses over masked frames",
+    )
+
+    parser.add_argument(
+        "--skip-nomask",
+        type=bool,
+        default=False,
+        help="skip computing losses over unmasked frames",
+    )
+
+    parser.add_argument(
+        "--checkpoint-activations",
+        type=bool,
+        default=False,
+        help="recompute activations and save memory for extra compute",
+    )
+
+    parser.add_argument(
+        "--pred-masked-weight",
+        type=float,
+        default=1,
+        help="weight for the masked part in ssl loss",
+    )
+
+    parser.add_argument(
+        "--pred-nomask-weight",
+        type=float,
+        default=0,
+        help="weight for the unmasked part in ssl loss",
+    )
+
+    parser.add_argument(
+        "--loss-weights",
+        type=float,
+        nargs="*",
+        default=[10],
+        help="weights for additional loss terms (e.g. the feature penalty)",
+    )
+
+    # FP16 optimization
+    parser.add_argument(
+        "--required-seq-len-multiple",
+        type=int,
+        default=2,
+        help="pad the input to encoder such that the sequence length is divisible by multiple",
+    )
+
+    parser.add_argument(
+
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA" + ) + + parser.add_argument( + "--pos-enc-type", + type=str, + default="abs", + help="Positional encoding type to use in conformer", + ) + + parser.add_argument( + "--logit-temp", type=float, default=0.1, help="temperature to divide logits by" + ) + + parser.add_argument( + "--dropout-input", + type=float, + default=0.0, + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + default=0.0, + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--num-classes", + type=int, + nargs="*", + default=[504], + help="""num class, a little larger than the number of cluster, + the largest is for padding, + and the value should be the multiple of 4, for faster computation""", + ) + + parser.add_argument( + "--untie-final-proj", + type=bool, + default=False, + help="use separate projection for each target", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=222, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hubert/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--pretrained-dir", + type=str, + help="""The pretrained model dir. + It specifies the directory where the pretrained checkpoint is saved.""", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.001, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. 
We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--sanity-check", + type=str2bool, + default=False, + help="Check if any of the batches in epoch 1 would cause OOM.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=100000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accum-grad", + type=int, + default=1, + help="""update gradient when batch_idx_train % accum_grad == 0. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + + contains number of updates happen to the model so far across + epochs. + + - sub_batch_idx_train: It contains number of batch trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "sub_batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for pruned RNN-T loss + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + if hasattr(params, "pretrained_dir"): + logging.info(f"Loading {params.pretrained_dir}") + pretrained = torch.load(params.pretrained_dir) + encoder = HubertModel(params) + encoder.load_state_dict(pretrained["model"]) + else: + encoder = HubertModel(params) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + 
use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. 
It is an instance of Zipformer in our case. + batch: + A batch of data. See `dataset.HubertAsrDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, num_frames = model( + x=audio, + padding_mask=padding_mask, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = num_frames.sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. 
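+
+    Gradient accumulation: the optimizer is stepped only once every
+    `params.accum_grad` sub-batches; `params.batch_idx_train` counts optimizer
+    steps, while `params.sub_batch_idx_train` counts sub-batches.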
+ + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for sub_batch_idx, batch in enumerate(train_dl): + params.sub_batch_idx_train += 1 + batch_idx = sub_batch_idx // params.accum_grad + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss / params.accum_grad).backward() + + if sub_batch_idx % params.accum_grad == params.accum_grad - 1: + params.batch_idx_train += 1 + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
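+            # Concretely (per the checks below): every 100 batches the scale is
+            # doubled while it is below 8.0, and every 400 batches while it is
+            # below 32.0, so a scale that collapsed after overflows can recover.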
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if batch_idx % params.accum_grad != params.accum_grad - 1: + optimizer.zero_grad() + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if not params.use_transducer:
+        params.ctc_loss_scale = 1.0
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,  # should have no effect
+        clipping_scale=2.0,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=0)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            512
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = (
+        librispeech.train_all_shuf_cuts()
+        if params.full_libri
+        else librispeech.train_clean_100_cuts()
+    )
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training.
Duration: {c.duration}" + # ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + do_normalize=params.do_normalize, + sampler_state_dict=sampler_state_dict, + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + + valid_dl = librispeech.valid_dataloaders( + valid_cuts, + do_normalize=params.do_normalize, + ) + + if params.sanity_check and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `dataset.HubertAsrDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + audio = batch["audio"] + logging.info(f"audio shape: {audio.shape}") + + y = sp.encode(batch["supervisions"]["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/zipformer/hubert_ce.py b/egs/librispeech/SSL/zipformer/hubert_ce.py new file mode 100644 index 000000000..ba4e1cddd --- /dev/null +++ b/egs/librispeech/SSL/zipformer/hubert_ce.py @@ -0,0 +1,601 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
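+
+# Note: this module appears to be adapted from fairseq's HuBERT implementation,
+# with the Transformer encoder replaced by icefall's Zipformer2 and the masked
+# prediction loss computed as plain cross-entropy over the provided frame-level
+# targets (see HubertModel.compute_loss below).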
+ +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scaling import ScheduledFloat +from utils import GradMultiply, LayerNorm +from wav2vec2_module import ConvFeatureExtractionModel +from zipformer import Zipformer2 + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + if num_mask_ver == 1: + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + else: + raise ValueError() + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError(f"this should never happens") + else: + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in 
range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + raise ValueError( + ( + f"the entire sequence is masked. " + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) + ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +class HubertModel(nn.Module): + def __init__( + self, + cfg, + ) -> None: + super().__init__() + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate + encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0] + encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim)) + self.post_extract_proj = ( + nn.Linear(self.embed, encoder_input_dim) + if self.embed != encoder_input_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + + self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_()) + + self.encoder = Zipformer2( + output_downsampling_factor=1, + downsampling_factor=_to_int_tuple(cfg.downsampling_factor), + num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers), + encoder_dim=_to_int_tuple(cfg.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(cfg.query_head_dim), + pos_head_dim=_to_int_tuple(cfg.pos_head_dim), + value_head_dim=_to_int_tuple(cfg.value_head_dim), + pos_dim=cfg.pos_dim, + num_heads=_to_int_tuple(cfg.num_heads), + feedforward_dim=_to_int_tuple(cfg.feedforward_dim), + cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + 
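            # (Assuming ScheduledFloat behaves as in icefall's scaling.py: dropout
+            # starts near 0.3 and decays to 0.1 by roughly 20k batches.)
+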
warmup_batches=4000.0, + ) + + self.layer_norm = LayerNorm(self.embed) + + self.untie_final_proj = cfg.untie_final_proj + self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes)) + + # modules below are not needed during fine-tuning + self.num_classes = cfg.num_classes + self.pred_masked_weight = cfg.pred_masked_weight + self.pred_nomask_weight = cfg.pred_nomask_weight + self.loss_weights = cfg.loss_weights + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + def apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb.to(x.dtype) + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_features(self, source: torch.Tensor) -> torch.Tensor: + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + return features + + def forward_targets( + self, + features: torch.Tensor, + target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, target_list + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + return padding_mask + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None, + ): + """output layer is 1-based""" + features = self.forward_features(source) + if target_list is not None: + features, target_list = self.forward_targets(features, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, 
padding_mask)
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features)
+
+        features = self.dropout_input(features)
+        unmasked_features = self.dropout_features(unmasked_features)
+
+        if mask:
+            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+        else:
+            x = features
+            mask_indices = None
+
+        # feature: (B, T, D), float
+        # target: (B, T), long
+        # x: (B, T, D), float -> (T, B, D), float
+        # padding_mask: (B, T), bool
+        # mask_indices: (B, T), bool
+        x = x.transpose(0, 1)
+        # The encoder expects the number of valid (un-padded) frames per utterance.
+        x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1))
+        x = x.transpose(0, 1)
+
+        if features_only:
+            return {"x": x, "padding_mask": padding_mask, "features": features}
+
+        if not self.skip_masked:
+            masked_indices = torch.logical_and(~padding_mask, mask_indices)
+            proj_x_m = self.final_proj(x[masked_indices])
+            proj_x_m /= self.logit_temp
+            logit_m_list = [proj_x_m for _ in range(len(target_list))]
+        else:
+            logit_m_list = [None for _ in target_list]
+
+        if not self.skip_nomask:
+            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+            proj_x_u = self.final_proj(x[nomask_indices])
+            proj_x_u /= self.logit_temp
+            logit_u_list = [proj_x_u for _ in range(len(target_list))]
+        else:
+            logit_u_list = [None for _ in target_list]
+
+        # result = {
+        #     "logit_m_list": logit_m_list,
+        #     "logit_u_list": logit_u_list,
+        #     "padding_mask": padding_mask,
+        #     "features_pen": features_pen,
+        # }
+        targ_m_list = target_list[0][masked_indices]
+        targ_m_list = targ_m_list.long()
+        targ_m_list = [targ_m_list for _ in range(len(target_list))]
+
+        targ_u_list = target_list[0][nomask_indices]
+        targ_u_list = targ_u_list.long()
+        targ_u_list = [targ_u_list for _ in range(len(target_list))]
+        return self.compute_loss(
+            logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
+        )
+
+    def extract_features(
+        self,
+        source: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        mask: bool = False,
+        ret_conv: bool = False,
+        output_layer: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        res = self.forward(
+            source,
+            padding_mask=padding_mask,
+            mask=mask,
+            features_only=True,
+            output_layer=output_layer,
+        )
+        feature = res["features"] if ret_conv else res["x"]
+        return feature, res["padding_mask"]
+
+    def get_logits(self, net_output, is_masked=True):
+        if is_masked:
+            logits_list = net_output["logit_m_list"]
+        else:
+            logits_list = net_output["logit_u_list"]
+        logits_list = [x.float() for x in logits_list if x is not None]
+        return logits_list
+
+    def get_targets(self, net_output, is_masked=True):
+        logits_list = self.get_logits(net_output, is_masked)
+        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
+        return targets_list
+
+    def get_extra_losses(self, net_output):
+        extra_losses = []
+        names = []
+
+        if "features_pen" in net_output:
+            extra_losses.append(net_output["features_pen"])
+            names.append("features_pen")
+
+        return extra_losses, names
+
+    def remove_pretraining_modules(self):
+        self.final_proj = None
+
+    def compute_loss(
+        self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
+    ):
+        loss = 0.0
+        sample_size = 0
+        logging_output = {}
+        reduce = True
+        reduction = "sum" if reduce else "none"
+
+        loss_m_list = []
+        logp_m_list = [x.float() for x in logit_m_list if x is not None]
+        logp_m_list = torch.cat(logp_m_list)
+        targ_m_list = torch.cat(targ_m_list)
+
+        loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
+        loss_m_list.append(loss_m)
+        logging_output[f"loss_m_0"] =
loss_m.detach().item() + + assert self.pred_masked_weight == 0 or len(logp_m_list) > 0 + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += len(targ_m_list) + + loss_u_list = [] + logp_u_list = [x.float() for x in logit_u_list if x is not None] + logp_u_list = torch.cat(logp_u_list) + targ_u_list = torch.cat(targ_u_list) + + loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_0"] = loss_u.detach().item() + + assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0 + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += len(targ_u_list) + + if self.loss_weights is not None: + extra_losses = [] + names = [] + extra_losses.append(features_pen) + names.append("features_pen") + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + logging_output = { + "loss": loss.item() if reduce else loss, + **logging_output, + } + + # for lk in self.log_keys: + # if lk in net_output: + # logging_output[lk] = float((net_output[lk])) + + def compute_correct(logits, target): + if logits.numel() == 0: + return 0, 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == target + min = logits.argmin(-1) == target + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + return corr, count + + with torch.no_grad(): + corr_m, count_m = compute_correct(logp_m_list, targ_m_list) + logging_output[f"correct_m_0"] = corr_m + logging_output[f"count_m_0"] = count_m + + corr_u, count_u = compute_correct(logp_u_list, targ_u_list) + logging_output[f"correct_u_0"] = corr_u + logging_output[f"count_u_0"] = count_u + + return loss, sample_size, logging_output diff --git a/egs/librispeech/SSL/zipformer/joiner.py b/egs/librispeech/SSL/zipformer/joiner.py new file mode 120000 index 000000000..aa3362cda --- /dev/null +++ b/egs/librispeech/SSL/zipformer/joiner.py @@ -0,0 +1 @@ +../../ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py new file mode 100644 index 000000000..46a968b69 --- /dev/null +++ b/egs/librispeech/SSL/zipformer/model.py @@ -0,0 +1,344 @@ +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
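+
+# AsrModel couples the SSL (HuBERT/Zipformer2) encoder with an optional pruned
+# transducer head and an optional CTC head; at least one of the two must be
+# enabled (see the assertion in __init__).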
+ +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from scaling import ScaledLinear + +from icefall.utils import add_sos + + +class AsrModel(nn.Module): + def __init__( + self, + encoder, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + encoder_dim: int = 768, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder: + It is the transcription network in the paper. Its accepts + inputs: `x` of (N, T, encoder_dim). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + self.encoder = encoder + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.25 + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.25 + ) + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward_encoder( + self, + x: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 2-D tensor of shape (N, T). + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + if padding_mask is None: + padding_mask = torch.zeros_like(x, dtype=torch.bool) + + encoder_out, padding_mask = self.encoder.extract_features( + source=x, + padding_mask=padding_mask, + mask=self.encoder.training, + ) + encoder_out_lens = torch.sum(~padding_mask, dim=1) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + + return encoder_out, encoder_out_lens + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). 
+ encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets, + input_lengths=encoder_out_lens, + target_lengths=target_lengths, + reduction="sum", + ) + return ctc_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
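Stepping back to forward_ctc() defined a little earlier in this file: it is a thin wrapper around torch.nn.functional.ctc_loss, with the log-probs permuted to (T, N, C) and the targets passed un-padded as a single concatenated 1-D tensor. A self-contained toy call (random data, made-up lengths) showing that layout, not the recipe's own code:

```python
import torch
import torch.nn.functional as F

N, T, C = 2, 50, 500                      # batch, frames, vocab size (blank = 0)
log_probs = torch.randn(N, T, C).log_softmax(dim=-1)   # like ctc_output above

target_lengths = torch.tensor([7, 4])
targets = torch.randint(1, C, (int(target_lengths.sum()),))  # concatenated labels
input_lengths = torch.tensor([T, T - 10])  # per-utterance encoder output lengths

ctc_loss = F.ctc_loss(
    log_probs=log_probs.permute(1, 0, 2),  # (T, N, C)
    targets=targets,
    input_lengths=input_lengths,
    target_lengths=target_lengths,
    reduction="sum",
)
print(ctc_loss.item())
```

In the recipe, targets comes from y.values of the k2.RaggedTensor and target_lengths from its row lengths, as assembled in forward() further below.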
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + y: k2.RaggedTensor, + padding_mask: Optional[torch.Tensor] = None, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 2-D tensor of shape (N, T). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 2, x.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == y.dim0, (x.shape, y.dim0) + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + # Compute CTC loss + targets = y.values + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss, encoder_out_lens diff --git a/egs/librispeech/SSL/zipformer/optim.py b/egs/librispeech/SSL/zipformer/optim.py new file mode 120000 index 000000000..56b827b8a --- /dev/null +++ b/egs/librispeech/SSL/zipformer/optim.py @@ -0,0 +1 @@ +../../ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py new file mode 100644 index 000000000..5f547e0b8 --- /dev/null +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -0,0 +1,1380 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For hubert model pretraining: +./zipformer/pretrain.py \ + --world-size 8 \ + --num-epochs 400 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir hubert/exp \ + --full-libri 1 \ + --max-duration 87.5 \ + --accum-grad 4 +""" + + +import argparse +import copy +import logging +import sys +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from hubert_ce import HubertModel +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from ssl_datamodule import LibriSpeechDataModule +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * params.accum_grad + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + # hubert parameters + parser.add_argument( + "--label-rate", + type=float, + default=50, + ) + + parser.add_argument( + "--sample-rate", + type=float, + default=16000, + ) + + parser.add_argument( + "--extractor-mode", + type=str, + default="default", + help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. 
default has a single group norm with d groups in the first conv block,
+        whereas layer_norm has layer norms in every block (meant to be used with
+        normalize=True)""",
+    )
+
+    parser.add_argument(
+        "--conv-feature-layers",
+        type=str,
+        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+        help="string describing the convolutional feature extraction layers, "
+        "in the form of a python list that contains [(dim, kernel_size, stride), ...]",
+    )
+
+    parser.add_argument(
+        "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+    )
+
+    parser.add_argument(
+        "--feature-grad-mult",
+        type=float,
+        default=1.0,
+        help="multiply feature extractor gradients by this factor",
+    )
+
+    # masking
+    parser.add_argument(
+        "--mask-length", type=int, default=10, help="length of each mask span"
+    )
+
+    parser.add_argument(
+        "--mask-prob",
+        type=float,
+        default=0.65,
+        help="probability of replacing a token with mask",
+    )
+
+    parser.add_argument(
+        "--mask-selection",
+        type=str,
+        choices=["static", "uniform", "normal", "poisson"],
+        default="static",
+        help="how to choose mask length",
+    )
+
+    parser.add_argument(
+        "--mask-other",
+        type=float,
+        default=0,
+        help="secondary mask argument (used for more complex distributions), "
+        "see help in compute_mask_indices",
+    )
+
+    parser.add_argument(
+        "--no-mask-overlap",
+        type=bool,
+        default=False,
+        help="whether to allow masks to overlap",
+    )
+
+    parser.add_argument(
+        "--mask-min-space",
+        type=int,
+        default=1,
+        help="min space between spans (if no overlap is enabled)",
+    )
+
+    # channel masking
+    parser.add_argument(
+        "--mask-channel-length",
+        type=int,
+        default=10,
+        help="length of the mask for features (channels)",
+    )
+
+    parser.add_argument(
+        "--mask-channel-prob",
+        type=float,
+        default=0.0,
+        help="probability of replacing a feature with 0",
+    )
+
+    parser.add_argument(
+        "--mask-channel-selection",
+        type=str,
+        choices=["static", "uniform", "normal", "poisson"],
+        default="static",
+        help="how to choose mask length for channel masking",
+    )
+
+    parser.add_argument(
+        "--mask-channel-other",
+        type=float,
+        default=0,
+        help="secondary mask argument (used for more complex distributions), "
+        "see help in compute_mask_indices",
+    )
+
+    parser.add_argument(
+        "--no-mask-channel-overlap",
+        type=bool,
+        default=False,
+        help="whether to allow channel masks to overlap",
+    )
+
+    parser.add_argument(
+        "--mask-channel-min-space",
+        type=int,
+        default=1,
+        help="min space between spans (if no overlap is enabled)",
+    )
+
+    # loss computation
+    parser.add_argument(
+        "--skip-masked",
+        type=bool,
+        default=False,
+        help="skip computing losses over masked frames",
+    )
+
+    parser.add_argument(
+        "--skip-nomask",
+        type=bool,
+        default=False,
+        help="skip computing losses over unmasked frames",
+    )
+
+    parser.add_argument(
+        "--checkpoint-activations",
+        type=bool,
+        default=False,
+        help="recompute activations and save memory for extra compute",
+    )
+
+    parser.add_argument(
+        "--pred-masked-weight",
+        type=float,
+        default=1,
+        help="weight for the masked part of the ssl loss",
+    )
+
+    parser.add_argument(
+        "--pred-nomask-weight",
+        type=float,
+        default=0,
+        help="weight for the unmasked part of the ssl loss",
+    )
+
+    parser.add_argument(
+        "--loss-weights",
+        type=float,
+        nargs="*",
+        default=[10],
+        help="weights for the extra losses (e.g. the feature penalty) in the ssl loss",
+    )
+
+    # FP16 optimization
+    parser.add_argument(
+        "--required-seq-len-multiple",
+        type=int,
+        default=2,
+        help="pad the input to the encoder so that the sequence length is divisible by this multiple",
+    )
+
+    parser.add_argument(
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA" + ) + + parser.add_argument( + "--pos-enc-type", + type=str, + default="abs", + help="Positional encoding type to use in conformer", + ) + + parser.add_argument( + "--logit-temp", type=float, default=0.1, help="temperature to divide logits by" + ) + + parser.add_argument( + "--dropout-input", + type=float, + default=0.0, + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + default=0.0, + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--num-classes", + type=int, + nargs="*", + default=[504], + help="""num class, a little larger than the number of cluster, + the largest is for padding, + and the value should be the multiple of 4, for faster computation""", + ) + + parser.add_argument( + "--untie-final-proj", + type=bool, + default=False, + help="use separate projection for each target", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=400, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. 
+ """, + ) + + parser.add_argument( + "--warmup-batches", + type=float, + default=5000, + help="Eden warmup steps", + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0, + help="Eden warmup start learning rate", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--sanity-check", + type=str2bool, + default=False, + help="Check if any of the batches in epoch 1 would cause OOM.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=100000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accum-grad", + type=int, + default=4, + help="""update gradient when batch_idx_train % accum_grad == 0. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--max-keep-size", + type=int, + default=sys.maxsize, + help="exclude sample longer than this.", + ) + + parser.add_argument( + "--min-keep-size", + type=float, + default=32000, + help="exclude sample longer less than this.", + ) + + parser.add_argument( + "--max-sample-size", + type=float, + default=250000, + help="max sample size to crop to for batching.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. 
It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of updates happen to the model so far across + epochs. + + - sub_batch_idx_train: It contains number of batch trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "sub_batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_model(params: AttributeDict) -> nn.Module: + model = HubertModel(params) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. 
+ model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `dataset.HubertDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + audio = batch["audio"].to(device) + padding_mask = batch["padding_mask"].to(device) + kmeans = batch["kmeans"].to(device) + + with torch.set_grad_enabled(is_training): + loss, num_masked_tokens, logging_output = model( + source=audio, target_list=[kmeans], padding_mask=padding_mask + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = num_masked_tokens + for item in logging_output: + info[item] = logging_output[item] + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. 
+ optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for sub_batch_idx, batch in enumerate(train_dl): + params.sub_batch_idx_train += 1 + batch_idx = sub_batch_idx // params.accum_grad + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + batch_size = batch["kmeans"].shape[0] + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss / params.accum_grad).backward() + + if sub_batch_idx % params.accum_grad == params.accum_grad - 1: + params.batch_idx_train += 1 + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
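As a side note on the loop above: gradients are accumulated over --accum-grad sub-batches, with the loss divided by accum_grad so the summed gradient matches one large batch, and the optimizer is stepped through the GradScaler. A minimal CPU-runnable sketch of that pattern (toy model and data, scaler disabled to stand in for use_fp16=False; not the recipe's loop):

```python
import torch
from torch.cuda.amp import GradScaler

accum_grad = 4
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=False)   # enabled=params.use_fp16 in the recipe

for sub_batch_idx in range(8):
    x = torch.randn(2, 8)
    loss = model(x).pow(2).sum()
    # Divide by accum_grad so the accumulated gradient matches one big batch.
    scaler.scale(loss / accum_grad).backward()
    if sub_batch_idx % accum_grad == accum_grad - 1:
        scaler.step(optimizer)       # one optimizer update per accum_grad sub-batches
        scaler.update()
        optimizer.zero_grad()
```

scaler.scale() and scaler.step() degrade gracefully to a plain backward()/step() when the scaler is disabled, which is why the same code path serves both FP16 and FP32 training.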
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if batch_idx % params.accum_grad != params.accum_grad - 1: + optimizer.zero_grad() + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden( + optimizer, + params.lr_batches, + params.lr_epochs, + params.warmup_batches, + params.warmup_start, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechDataModule(args) + + train_cuts = ( + librispeech.train_all_shuf_cuts() + if params.full_libri + else librispeech.train_clean_100_cuts() + ) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if ( + c.duration < params.min_keep_size / params.sample_rate + or c.duration > params.max_keep_size / params.sample_rate + ): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + max_sample_size=params.max_sample_size, + sample_rate=params.sample_rate, + label_rate=params.label_rate, + random_crop=params.random_crop, + pad_audio=False, + num_classes=params.num_classes, + do_normalize=params.do_normalize, + sampler_state_dict=sampler_state_dict, + ) + + valid_cuts = librispeech.dev_clean_cuts() + # valid_cuts += librispeech.dev_other_cuts() + valid_cuts = valid_cuts.filter(remove_short_and_long_utt) + + valid_dl = librispeech.valid_dataloaders( + valid_cuts, + max_sample_size=params.max_sample_size, + sample_rate=params.sample_rate, + label_rate=params.label_rate, + random_crop=params.random_crop, + pad_audio=False, + num_classes=params.num_classes, + do_normalize=params.do_normalize, + ) + + if params.sanity_check and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `dataset.HubertDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + audio = batch["audio"] + logging.info(f"audio shape: {audio.shape}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/SSL/zipformer/scaling.py b/egs/librispeech/SSL/zipformer/scaling.py new file mode 120000 index 000000000..e30bd99de --- /dev/null +++ b/egs/librispeech/SSL/zipformer/scaling.py @@ -0,0 +1 @@ +../../ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/ssl_datamodule.py b/egs/librispeech/SSL/zipformer/ssl_datamodule.py new file mode 120000 index 000000000..9f5085e3a --- /dev/null +++ b/egs/librispeech/SSL/zipformer/ssl_datamodule.py @@ -0,0 +1 @@ +../hubert/ssl_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/SSL/zipformer/utils.py b/egs/librispeech/SSL/zipformer/utils.py new file mode 100644 index 000000000..748d3c96e --- /dev/null +++ b/egs/librispeech/SSL/zipformer/utils.py @@ -0,0 +1,337 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
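The entry point above follows the usual one-process-per-GPU pattern. A stripped-down sketch of the same calling convention (the run body and args here are placeholders, not the recipe's):

```python
import torch
import torch.multiprocessing as mp


def run(rank: int, world_size: int, args: dict) -> None:
    # In pretrain.py this sets up DDP, builds the model and dataloaders and
    # drives train_one_epoch(); here it only shows how mp.spawn calls it.
    print(f"rank {rank}/{world_size} started with args={args}")


def main() -> None:
    args = {"exp_dir": "zipformer/exp"}   # placeholder for parsed arguments
    world_size = 2                        # e.g. the number of visible GPUs
    if world_size > 1:
        # mp.spawn prepends the process index (rank) to the argument tuple.
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


if __name__ == "__main__":
    torch.set_num_threads(1)
    main()
```

Keeping torch.set_num_threads(1), as the script does, avoids oversubscribing CPU threads when several ranks share one machine.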
+ +import math +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def relu_squared(x: torch.Tensor): + return F.relu(x).pow(2) + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == "xla" + + +def index_put(tensor, indices, value): + if is_xla_tensor(tensor): + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + +def pad_to_multiple(x, multiple, dim=-1, value=0): + # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41 + if x is None: + return None, 0 + tsz = x.size(dim) + m = tsz / multiple + remainder = math.ceil(m) * multiple - tsz + if m.is_integer(): + return x, 0 + pad_offset = (0,) * (-1 - dim) * 2 + + return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str) -> Callable: + """Returns the activation function corresponding to `activation`""" + if activation == "relu": + return F.relu + elif activation == "relu_squared": + return relu_squared + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "swish": + return torch.nn.SiLU + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class SamePad2d(nn.Module): + def __init__(self, kernel_size): + super().__init__() + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + assert len(x.size()) == 4 + if self.remove > 0: + x = x[:, :, : -self.remove, : -self.remove] + return x + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None, tranpose_dim=-2): + super().__init__() + self.deconstruct_idx = deconstruct_idx + self.tranpose_dim = tranpose_dim + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(self.tranpose_dim, -1) + + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + has_fused_layernorm = False + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + export = True + if not export and torch.cuda.is_available() and has_fused_layernorm: + return FusedLayerNorm(normalized_shape, 
eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), 
weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class FairseqDropout(nn.Module): + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x, inplace: bool = False): + if self.p > 0 and (self.training or self.apply_during_inference): + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + + def make_generation_fast_( + self, + name: str, + retain_dropout: bool = False, + retain_dropout_modules: Optional[List[str]] = None, + **kwargs + ): + if retain_dropout: + if retain_dropout_modules is not None and self.module_name is None: + pass + elif ( + retain_dropout_modules is None # if None, apply to all modules + or self.module_name in retain_dropout_modules + ): + self.apply_during_inference = True + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/egs/librispeech/SSL/zipformer/wav2vec2_module.py b/egs/librispeech/SSL/zipformer/wav2vec2_module.py new file mode 100644 index 000000000..ab5ca005f --- /dev/null +++ b/egs/librispeech/SSL/zipformer/wav2vec2_module.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
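One helper above is worth a quick illustration: GradMultiply is an identity in the forward pass but scales gradients on the way back, which is presumably how --feature-grad-mult damps updates to the convolutional frontend. A standalone check (the class is copied from utils.py above, with comments added, so the snippet runs on its own):

```python
import torch


class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.new(x)                   # identity on the values

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None     # scale the gradient; no grad for `scale`


x = torch.ones(3, requires_grad=True)
y = GradMultiply.apply(x, 0.1)
y.sum().backward()
print(y.detach())   # tensor([1., 1., 1.])            -- forward unchanged
print(x.grad)       # tensor([0.1000, 0.1000, 0.1000]) -- gradient scaled by 0.1
```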
+ +import math +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py new file mode 100644 index 000000000..e9eff3357 --- /dev/null +++ b/egs/librispeech/SSL/zipformer/zipformer.py @@ -0,0 +1,2438 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. 
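Before moving on to the Zipformer2 encoder, a quick sanity check on the frontend just defined: the strides in the default --conv-feature-layers spec multiply to 5*2*2*2*2*2*2 = 320, so 16 kHz waveforms come out at roughly 50 frames per second, matching --label-rate 50. The snippet below uses bare nn.Conv1d stand-ins with random weights (no dropout, normalization or GELU) purely to verify the length arithmetic; it is not the ConvFeatureExtractionModel class itself, and the eval of the spec string is just a convenient way to build the toy stack.

```python
import torch
import torch.nn as nn

# The default spec string from pretrain.py, turned into [(dim, kernel, stride), ...].
conv_layers = eval("[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")

blocks, in_d = [], 1
for dim, k, stride in conv_layers:
    blocks.append(nn.Conv1d(in_d, dim, k, stride=stride, bias=False))
    in_d = dim
frontend = nn.Sequential(*blocks)

wav = torch.randn(1, 1, 16000)   # one second of 16 kHz audio, shape (B, C, T)
with torch.no_grad():
    feats = frontend(wav)
print(feats.shape)               # torch.Size([1, 512, 49]) -- about 50 frames/s
```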
+) +from scaling import ( + ActivationDropoutAndLinear, + Balancer, + BiasNorm, + ChunkCausalDepthwiseConv1d, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. 
+ """ + + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + for u, d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. 
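# How _to_tuple() above broadcasts the per-stack hyper-parameters: a plain int
# (or a 1-element tuple) is repeated once per encoder stack, i.e. once per
# entry of downsampling_factor.  Standalone sketch of the same rule, using this
# class's defaults (downsampling_factor=(2, 4), encoder_dim=384):
downsampling_factor = (2, 4)  # two encoder stacks

def to_tuple(x, ref=downsampling_factor):
    if isinstance(x, int):
        x = (x,)
    return x * len(ref) if len(x) == 1 else x

assert to_tuple(384) == (384, 384)     # encoder_dim broadcast to both stacks
assert to_tuple((4,)) == (4, 4)        # num_encoder_layers
assert to_tuple((24, 24)) == (24, 24)  # already given per stack: kept as-is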
+ encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. 
+ """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + ) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + # x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + # assert self.output_downsampling_factor == 2, self.output_downsampling_factor + # if torch.jit.is_scripting() or torch.jit.is_tracing(): + # lengths = (x_lens + 1) // 2 + # else: + # with warnings.catch_warnings(): + # warnings.simplefilter("ignore") + # lengths = (x_lens + 1) // 2 + + # return x, lengths + return x, x_lens + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). 
+ chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. 
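# Standalone sketch of the chunk-wise causal attention mask built by
# _get_attn_mask() above: a frame may attend to frames in its own chunk and in
# at most `left_context_chunks` preceding chunks; True means "masked".
import torch

seq_len, chunk_size, left_context_chunks = 8, 2, 1
t = torch.arange(seq_len)
c = t // chunk_size  # chunk index of each frame
src_c, tgt_c = c, c.unsqueeze(-1)
attn_mask = torch.logical_or(
    src_c > tgt_c,                        # no attention to future chunks
    src_c < tgt_c - left_context_chunks,  # limited left context
)
# Frame 4 sits in chunk 2, so it may attend to frames 2..5 (chunks 1 and 2):
assert (~attn_mask[4]).nonzero().flatten().tolist() == [2, 3, 4, 5]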
+ assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. 
+ + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. 
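# The FloatLike schedules used for these skip rates (e.g. attention_skip_rate
# above, ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0))) depend on
# how many batches have been trained.  The helper below is only an illustrative
# stand-in, not the scaling.py implementation: it assumes piecewise-linear
# interpolation between (batch_count, value) breakpoints, clamped at the ends.
def schedule(batch_count, points):
    (b0, v0) = points[0]
    if batch_count <= b0:
        return v0
    for (b1, v1) in points[1:]:
        if batch_count <= b1:
            return v0 + (v1 - v0) * (batch_count - b0) / (b1 - b0)
        b0, v0 = b1, v1
    return v0  # past the last breakpoint

attention_skip = [(0.0, 0.2), (4000.0, 0.05), (16000.0, 0.0)]
assert schedule(0, attention_skip) == 0.2
assert abs(schedule(2000, attention_skip) - 0.125) < 1e-9  # halfway between 0.2 and 0.05
assert schedule(50000, attention_skip) == 0.0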
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). 
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif not self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) + + # bypass in the middle of the layer. 
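# The sequence_dropout() calls above drop a module's entire output for randomly
# chosen sequences in the batch (mask shape (batch_size, 1)), rather than
# dropping individual elements.  Minimal standalone sketch of that broadcasting:
import torch

seq_len, batch_size, embed_dim, dropout_rate = 5, 4, 3, 0.5
x = torch.randn(seq_len, batch_size, embed_dim)
mask = (torch.rand(batch_size, 1) > dropout_rate).to(x.dtype)  # one keep/drop decision per sequence
y = x * mask  # broadcasts over seq_len and embed_dim
for b in range(batch_size):
    # every sequence is either kept untouched or zeroed out entirely
    assert torch.equal(y[:, b], x[:, b]) or torch.equal(y[:, b], torch.zeros_like(y[:, b]))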
+ src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. 
+ + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.bypass(src_orig, src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. 
+ + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + return output + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + output, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + output, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + return output, new_states + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. 
+ """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, downsample, dropout) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. 
+ """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Downsample, go through encoder, upsample, in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); + True means masked position. May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + src_orig = src + src = self.downsample(src) + + src, new_states = self.encoder.streaming_forward( + src, + states=states, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src), new_states + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. 
+ """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. 
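# Standalone sketch of the offset compression computed just below: a relative
# offset is first compressed logarithmically and then squashed through atan(),
# so offsets near zero stay clearly separated while e.g. 1000 vs. 1001 map to
# nearly the same angle.  Uses this file's default pos_dim=192 (the embed_dim
# of this positional encoding) and length_factor=1.0.
import math
import torch

embed_dim, length_factor = 192, 1.0
compression_length = embed_dim ** 0.5
length_scale = length_factor * embed_dim / (2.0 * math.pi)

x = torch.tensor([0.0, 1.0, 1000.0, 1001.0])
x_compressed = (
    compression_length
    * x.sign()
    * ((x.abs() + compression_length).log() - math.log(compression_length))
)
x_atan = (x_compressed / length_scale).atan()  # everything lands in (-pi/2, pi/2)
assert x_atan.abs().max() < math.pi / 2
# offsets 0 and 1 are far better separated than offsets 1000 and 1001
assert (x_atan[1] - x_atan[0]) > 10 * (x_atan[3] - x_atan[2])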
+ # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. 
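# The single in_proj created just below packs the query, key and positional
# query projections of all heads into one matrix; forward() later slices them
# back out of the channel axis.  With this file's defaults (num_heads=8,
# query_head_dim=24, pos_head_dim=4) the widths work out as follows:
num_heads, query_head_dim, pos_head_dim = 8, 24, 4
key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
query_dim = query_head_dim * num_heads
assert (in_proj_dim, query_dim) == (416, 192)
# channels [0, 192)   -> queries q
# channels [192, 384) -> keys k
# channels [384, 416) -> positional queries p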
+ + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. 
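# Shape walk-through of the permutes and matmul that follow: after permuting,
# q is (head, batch, time1, query_head_dim) and k is (head, batch,
# query_head_dim, time2), so torch.matmul(q, k) yields raw attention scores of
# shape (num_heads, batch_size, time1, time2).  Standalone check with this
# file's default head sizes:
import torch

seq_len, batch_size, num_heads, query_head_dim = 6, 2, 8, 24
q = torch.randn(seq_len, batch_size, num_heads, query_head_dim).permute(2, 1, 0, 3)
k = torch.randn(seq_len, batch_size, num_heads, query_head_dim).permute(2, 1, 3, 0)
attn_scores = torch.matmul(q, k)
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)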
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. 
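# Quick numerical check of the -1000 fill applied just below (illustrated here
# with torch's built-in softmax; the custom softmax from scaling.py used later
# is described there as a memory optimization): exp(-1000) underflows to
# exactly zero after the softmax, and even a fully masked row stays finite
# instead of becoming NaN.
import torch

scores = torch.tensor([[-1000.0, 0.0, 0.0],
                       [-1000.0, -1000.0, -1000.0]])
weights = scores.softmax(dim=-1)
assert weights[0, 0] == 0.0                                   # masked position gets zero weight
assert torch.allclose(weights[1], torch.full((3,), 1.0 / 3))  # all-masked row: uniform, no NaNs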
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + left_context_len + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] 
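# Standalone check of the .as_strided() relative-to-absolute trick used just
# below (and earlier in forward()): for target frame i and source frame j, the
# score is read from relative-position index j - i + seq_len - 1.  Shown here
# for the non-streaming case, i.e. left_context_len == 0:
import torch

num_heads, batch_size, seq_len = 2, 1, 5
rel = torch.randn(num_heads, batch_size, seq_len, 2 * seq_len - 1)
absolute = rel.as_strided(
    (num_heads, batch_size, seq_len, seq_len),
    (rel.stride(0), rel.stride(1), rel.stride(2) - rel.stride(3), rel.stride(3)),
    storage_offset=rel.stride(3) * (seq_len - 1),
)
for i in range(seq_len):
    for j in range(seq_len):
        assert torch.equal(absolute[:, :, i, j], rel[:, :, i, j - i + seq_len - 1])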
+ pos_scores = torch.matmul(p, pos_emb) + + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. 
+ """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
+ x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. 
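+ # Below, x is mixed along the time axis by multiplying with the externally
+ # computed attention weights; this is the "multiplication by attention
+ # weights in place of actual convolution" described in the class docstring,
+ # with the tanh-gated s above applied elementwise as a gate.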
+ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. 
(for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). 
+ cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + + c = Zipformer2( + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) From f2e36ec4148d77f7f6a0105b882797a1579374ed Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 9 Apr 2024 11:37:08 +0800 Subject: [PATCH 148/216] Zipformer recipe for CommonVoice (#1546) * added scripts for char-based lang prep training scripts * added `Zipformer` recipe for commonvoice --------- Co-authored-by: Fangjun Kuang --- egs/commonvoice/ASR/RESULTS.md | 89 +- .../local/compute_fbank_commonvoice_splits.py | 33 +- egs/commonvoice/ASR/local/prepare_char.py | 1 + egs/commonvoice/ASR/local/prepare_lang.py | 1 + egs/commonvoice/ASR/local/prepare_lang_fst.py | 1 + .../ASR/local/preprocess_commonvoice.py | 46 +- egs/commonvoice/ASR/local/word_segment_yue.py | 147 ++ egs/commonvoice/ASR/prepare.sh | 396 +++-- .../asr_datamodule.py | 16 + .../ASR/pruned_transducer_stateless7/train.py | 39 +- .../asr_datamodule.py | 1 + .../commonvoice_fr.py | 426 ----- .../decode.py | 7 +- .../do_not_use_it_directly.py | 7 +- .../finetune.py | 9 +- .../streaming_decode.py | 6 +- .../train.py | 41 +- .../ASR/zipformer/asr_datamodule.py | 1 + egs/commonvoice/ASR/zipformer/beam_search.py | 1 + egs/commonvoice/ASR/zipformer/decode.py | 1052 ++++++++++++ egs/commonvoice/ASR/zipformer/decode_char.py | 813 ++++++++++ .../ASR/zipformer/decode_stream.py | 1 + egs/commonvoice/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-ctc.py | 1 + .../zipformer/export-onnx-streaming-ctc.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/commonvoice/ASR/zipformer/export-onnx.py | 1 + egs/commonvoice/ASR/zipformer/export.py | 1 + egs/commonvoice/ASR/zipformer/joiner.py | 
1 + egs/commonvoice/ASR/zipformer/model.py | 1 + egs/commonvoice/ASR/zipformer/onnx_check.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/commonvoice/ASR/zipformer/optim.py | 1 + egs/commonvoice/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 859 ++++++++++ .../ASR/zipformer/streaming_decode_char.py | 861 ++++++++++ egs/commonvoice/ASR/zipformer/subsampling.py | 1 + egs/commonvoice/ASR/zipformer/train.py | 1411 +++++++++++++++++ egs/commonvoice/ASR/zipformer/train_char.py | 1051 ++++++++++++ egs/commonvoice/ASR/zipformer/zipformer.py | 1 + 43 files changed, 6762 insertions(+), 571 deletions(-) create mode 120000 egs/commonvoice/ASR/local/prepare_char.py create mode 120000 egs/commonvoice/ASR/local/prepare_lang.py create mode 120000 egs/commonvoice/ASR/local/prepare_lang_fst.py create mode 100755 egs/commonvoice/ASR/local/word_segment_yue.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py delete mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py create mode 120000 egs/commonvoice/ASR/zipformer/asr_datamodule.py create mode 120000 egs/commonvoice/ASR/zipformer/beam_search.py create mode 100755 egs/commonvoice/ASR/zipformer/decode.py create mode 100755 egs/commonvoice/ASR/zipformer/decode_char.py create mode 120000 egs/commonvoice/ASR/zipformer/decode_stream.py create mode 120000 egs/commonvoice/ASR/zipformer/decoder.py create mode 120000 egs/commonvoice/ASR/zipformer/encoder_interface.py create mode 120000 egs/commonvoice/ASR/zipformer/export-onnx-ctc.py create mode 120000 egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py create mode 120000 egs/commonvoice/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/commonvoice/ASR/zipformer/export-onnx.py create mode 120000 egs/commonvoice/ASR/zipformer/export.py create mode 120000 egs/commonvoice/ASR/zipformer/joiner.py create mode 120000 egs/commonvoice/ASR/zipformer/model.py create mode 120000 egs/commonvoice/ASR/zipformer/onnx_check.py create mode 120000 egs/commonvoice/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/commonvoice/ASR/zipformer/optim.py create mode 120000 egs/commonvoice/ASR/zipformer/scaling.py create mode 120000 egs/commonvoice/ASR/zipformer/scaling_converter.py create mode 120000 egs/commonvoice/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/commonvoice/ASR/zipformer/streaming_decode.py create mode 100755 egs/commonvoice/ASR/zipformer/streaming_decode_char.py create mode 120000 egs/commonvoice/ASR/zipformer/subsampling.py create mode 100755 egs/commonvoice/ASR/zipformer/train.py create mode 100755 egs/commonvoice/ASR/zipformer/train_char.py create mode 120000 egs/commonvoice/ASR/zipformer/zipformer.py diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md index 2c158d91d..f384f66a0 100644 --- a/egs/commonvoice/ASR/RESULTS.md +++ b/egs/commonvoice/ASR/RESULTS.md @@ -1,20 +1,91 @@ ## Results -### GigaSpeech BPE training results (Pruned Stateless Transducer 7) + +### Commonvoice Cantonese (zh-HK) Char training results (Zipformer) + +See #1546 for more details. 
+ +Number of model parameters: 72526519, i.e., 72.53 M + +The best CER, for CommonVoice 16.1 (cv-corpus-16.1-2023-12-06/zh-HK) is below: + +| | Dev | Test | Note | +|----------------------|-------|------|--------------------| +| greedy_search | 1.17 | 1.22 | --epoch 24 --avg 5 | +| modified_beam_search | 0.98 | 1.11 | --epoch 24 --avg 5 | +| fast_beam_search | 1.08 | 1.27 | --epoch 24 --avg 5 | + +When doing the cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (w/o blank penalty), +the best CER is below: + +| | Dev | Test | Note | +|----------------------|-------|------|--------------------| +| greedy_search | 42.40 | 42.03| --epoch 24 --avg 5 | +| modified_beam_search | 39.73 | 39.19| --epoch 24 --avg 5 | +| fast_beam_search | 42.14 | 41.98| --epoch 24 --avg 5 | + +When doing the cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (with blank penalty set to 2.2), +the best CER is below: + +| | Dev | Test | Note | +|----------------------|-------|------|----------------------------------------| +| greedy_search | 39.19 | 39.09| --epoch 24 --avg 5 --blank-penalty 2.2 | +| modified_beam_search | 37.73 | 37.65| --epoch 24 --avg 5 --blank-penalty 2.2 | +| fast_beam_search | 37.73 | 37.74| --epoch 24 --avg 5 --blank-penalty 2.2 | + +To reproduce the above result, use the following commands for training: + +```bash +export CUDA_VISIBLE_DEVICES="0,1" +./zipformer/train_char.py \ + --world-size 2 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --cv-manifest-dir data/zh-HK/fbank \ + --language zh-HK \ + --use-validated-set 1 \ + --context-size 1 \ + --max-duration 1000 +``` + +and the following commands for decoding: + +```bash +for method in greedy_search modified_beam_search fast_beam_search; do + ./zipformer/decode_char.py \ + --epoch 24 \ + --avg 5 \ + --decoding-method $method \ + --exp-dir zipformer/exp \ + --cv-manifest-dir data/zh-HK/fbank \ + --context-size 1 \ + --language zh-HK +done +``` + +Detailed experimental results and pre-trained model are available at: + + + +### CommonVoice English (en) BPE training results (Pruned Stateless Transducer 7) #### [pruned_transducer_stateless7](./pruned_transducer_stateless7) -See #997 for more details. +See #997 for more details. Number of model parameters: 70369391, i.e., 70.37 M +Note that the result is obtained using GigaSpeech transcript trained BPE model + The best WER, as of 2023-04-17, for Common Voice English 13.0 (cv-corpus-13.0-2023-03-09/en) is below: Results are: | | Dev | Test | |----------------------|-------|-------| -| greedy search | 9.96 | 12.54 | -| modified beam search | 9.86 | 12.48 | +| greedy_search | 9.96 | 12.54 | +| modified_beam_search | 9.86 | 12.48 | To reproduce the above result, use the following commands for training: @@ -55,10 +126,6 @@ and the following commands for decoding: Pretrained model is available at -The tensorboard log for training is available at - - - ### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming) #### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) @@ -73,9 +140,9 @@ Results are: | decoding method | Test | |----------------------|-------| -| greedy search | 9.95 | -| modified beam search | 9.57 | -| fast beam search | 9.67 | +| greedy_search | 9.95 | +| modified_beam_search | 9.57 | +| fast_beam_search | 9.67 | Note: This best result is trained on the full librispeech and gigaspeech, and then fine-tuned on the full commonvoice. 
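The `--blank-penalty` option used in the Cantonese cross-corpus (MDCC) results above makes the decoder less eager to emit the blank symbol, which can reduce deletions when there is a domain mismatch between training and test data. As a rough illustration of the idea only (this is not the icefall implementation, which runs a full transducer prediction network and joiner; the helper name below and the assumption that blank has token id 0 are made up for this sketch), a penalty is simply subtracted from the blank score before picking the best token at each step:

```python
import torch


def toy_greedy_search(logits: torch.Tensor, blank_penalty: float = 0.0) -> list:
    """Frame-by-frame greedy decoding over per-frame scores.

    logits: (num_frames, vocab_size) tensor; blank is assumed to be token 0.
    A positive blank_penalty is subtracted from the blank score, so non-blank
    tokens win the argmax more often (fewer deletions, possibly more insertions).
    """
    hyp = []
    for frame_scores in logits:
        scores = frame_scores.clone()
        scores[0] -= blank_penalty  # penalize the blank symbol
        token = int(scores.argmax())
        if token != 0:  # skip blanks in the output
            hyp.append(token)
    return hyp


if __name__ == "__main__":
    torch.manual_seed(0)
    fake_logits = torch.randn(8, 5)  # 8 frames, vocab of 5 (0 = blank)
    print(toy_greedy_search(fake_logits, blank_penalty=0.0))
    print(toy_greedy_search(fake_logits, blank_penalty=2.2))
```

Tuning the penalty (2.2 in the tables above) trades deletions for insertions, which is why it pays off on the out-of-domain MDCC sets while the in-domain CommonVoice results are reported without it.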
diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py index f31b45aa5..aa672609a 100755 --- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (Yifan Yang) +# Copyright 2023-2024 Xiaomi Corp. (Yifan Yang, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -17,7 +18,6 @@ import argparse import logging -from datetime import datetime from pathlib import Path import torch @@ -30,6 +30,8 @@ from lhotse import ( set_caching_enabled, ) +from icefall.utils import str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -41,6 +43,14 @@ torch.set_num_interop_threads(1) def get_args(): parser = argparse.ArgumentParser() + parser.add_argument( + "--subset", + type=str, + default="train", + choices=["train", "validated", "invalidated"], + help="""Dataset parts to compute fbank. """, + ) + parser.add_argument( "--language", type=str, @@ -66,28 +76,35 @@ def get_args(): "--num-splits", type=int, required=True, - help="The number of splits of the train subset", + help="The number of splits of the subset", ) parser.add_argument( "--start", type=int, default=0, - help="Process pieces starting from this number (inclusive).", + help="Process pieces starting from this number (included).", ) parser.add_argument( "--stop", type=int, default=-1, - help="Stop processing pieces until this number (exclusive).", + help="Stop processing pieces until this number (excluded).", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", ) return parser.parse_args() def compute_fbank_commonvoice_splits(args): - subset = "train" + subset = args.subset num_splits = args.num_splits language = args.language output_dir = f"data/{language}/fbank/cv-{language}_{subset}_split_{num_splits}" @@ -130,6 +147,10 @@ def compute_fbank_commonvoice_splits(args): keep_overlapping=False, min_duration=None ) + if args.perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + logging.info("Computing features") cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, diff --git a/egs/commonvoice/ASR/local/prepare_char.py b/egs/commonvoice/ASR/local/prepare_char.py new file mode 120000 index 000000000..42743b544 --- /dev/null +++ b/egs/commonvoice/ASR/local/prepare_char.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/prepare_lang.py b/egs/commonvoice/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/commonvoice/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/prepare_lang_fst.py b/egs/commonvoice/ASR/local/prepare_lang_fst.py new file mode 120000 index 000000000..c5787c534 --- /dev/null +++ b/egs/commonvoice/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py 
b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index dbacdd821..cc88ef8d7 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -21,7 +21,7 @@ import re from pathlib import Path from typing import Optional -from lhotse import CutSet, SupervisionSegment +from lhotse import CutSet from lhotse.recipes.utils import read_manifests_if_cached @@ -52,14 +52,20 @@ def normalize_text(utt: str, language: str) -> str: return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper() elif language == "pl": return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper() - elif language == "yue": - return ( - utt.replace(" ", "") - .replace(",", "") - .replace("。", " ") - .replace("?", "") - .replace("!", "") - .replace("?", "") + elif language in ["yue", "zh-HK"]: + # Mozilla Common Voice uses both "yue" and "zh-HK" for Cantonese + # Not sure why they decided to do this... + # None en/zh-yue tokens are manually removed here + + # fmt: off + tokens_to_remove = [",", "。", "?", "!", "?", "!", "‘", "、", ",", "\.", ":", ";", "「", "」", "“", "”", "~", "—", "ㄧ", "《", "》", "…", "⋯", "·", "﹒", ".", ":", "︰", "﹖", "(", ")", "-", "~", ";", "", "⠀", "﹔", "/", "A", "B", "–", "‧"] + + # fmt: on + utt = utt.upper().replace("\\", "") + return re.sub( + pattern="|".join([f"[{token}]" for token in tokens_to_remove]), + repl="", + string=utt, ) else: raise NotImplementedError( @@ -130,6 +136,28 @@ def preprocess_commonvoice( supervisions=m["supervisions"], ).resample(16000) + if partition == "validated": + logging.warning( + """ + The 'validated' partition contains the data of both 'train', 'dev' + and 'test' partitions. We filter out the 'dev' and 'test' partition + here. + """ + ) + dev_ids = src_dir / f"cv-{language}_dev_ids" + test_ids = src_dir / f"cv-{language}_test_ids" + assert ( + dev_ids.is_file() + ), f"{dev_ids} does not exist, please check stage 1 of the prepare.sh" + assert ( + test_ids.is_file() + ), f"{test_ids} does not exist, please check stage 1 of the prepare.sh" + dev_ids = dev_ids.read_text().strip().split("\n") + test_ids = test_ids.read_text().strip().split("\n") + cut_set = cut_set.filter( + lambda x: x.supervisions[0].id not in dev_ids + test_ids + ) + # Run data augmentation that needs to be done in the # time domain. logging.info(f"Saving to {raw_cuts_path}") diff --git a/egs/commonvoice/ASR/local/word_segment_yue.py b/egs/commonvoice/ASR/local/word_segment_yue.py new file mode 100755 index 000000000..35d262d10 --- /dev/null +++ b/egs/commonvoice/ASR/local/word_segment_yue.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script takes a text file "data/lang_char/text" as input, the file consist of +lines each containing a transcript, applies text norm and generates the following +files in the directory "data/lang_char": + - transcript_words.txt + - words.txt + - words_no_ids.txt +""" + +import argparse +import logging +import re +from pathlib import Path +from typing import List + +import pycantonese +from preprocess_commonvoice import normalize_text +from tqdm.auto import tqdm + +from icefall.utils import is_cjk, tokenize_by_CJK_char + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Prepare char lexicon", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input-file", + "-i", + default="data/yue/lang_char/text", + type=str, + help="The input text file", + ) + parser.add_argument( + "--output-dir", + "-o", + default="data/yue/lang_char/", + type=str, + help="The output directory", + ) + parser.add_argument( + "--lang", + "-l", + default="yue", + type=str, + help="The language", + ) + return parser + + +def get_word_segments(lines: List[str]) -> List[str]: + # the current pycantonese segmenter does not handle the case when the input + # is code switching, so we need to handle it separately + + new_lines = [] + + for line in tqdm(lines, desc="Segmenting lines"): + try: + if is_cs(line): # code switching + segments = [] + curr_str = "" + for segment in tokenize_by_CJK_char(line).split(" "): + if segment.strip() == "": + continue + try: + if not is_cjk(segment[0]): # en segment + if curr_str: + segments.extend(pycantonese.segment(curr_str)) + curr_str = "" + segments.append(segment) + else: # zh segment + curr_str += segment + # segments.extend(pycantonese.segment(segment)) + except Exception as e: + logging.error(f"Failed to process segment: {segment}") + raise + if curr_str: # process the last segment + segments.extend(pycantonese.segment(curr_str)) + new_lines.append(" ".join(segments) + "\n") + else: # not code switching + new_lines.append(" ".join(pycantonese.segment(line)) + "\n") + except Exception as e: + logging.error(f"Failed to process line: {line}") + raise e + return new_lines + + +def get_words(lines: List[str]) -> List[str]: + words = set() + for line in tqdm(lines, desc="Getting words"): + words.update(line.strip().split(" ")) + return list(words) + + +def is_cs(line: str) -> bool: + english_markers = r"[a-zA-Z]+" + return bool(re.search(english_markers, line)) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + + input_file = Path(args.input_file) + output_dir = Path(args.output_dir) + lang = args.lang + + assert input_file.is_file(), f"{input_file} does not exist" + assert output_dir.is_dir(), f"{output_dir} does not exist" + + lines = input_file.read_text(encoding="utf-8").strip().split("\n") + norm_lines = [normalize_text(line, lang) for line in lines] + + text_words_segments = get_word_segments(norm_lines) + with open(output_dir / "transcript_words.txt", "w", encoding="utf-8") as f: + f.writelines(text_words_segments) + + words = get_words(text_words_segments)[1:] # remove "\n" from words + with open(output_dir / "words_no_ids.txt", "w", encoding="utf-8") as f: + f.writelines([word + "\n" for word in sorted(words)]) + + words = ( + ["", "!SIL", "", ""] + + sorted(words) + + ["#0", "", "<\s>"] + ) + + with open(output_dir / "words.txt", "w", encoding="utf-8") as f: + f.writelines([f"{word} {i}\n" for i, word in enumerate(words)]) diff --git a/egs/commonvoice/ASR/prepare.sh 
b/egs/commonvoice/ASR/prepare.sh index edac0e8e6..4e76ef041 100755 --- a/egs/commonvoice/ASR/prepare.sh +++ b/egs/commonvoice/ASR/prepare.sh @@ -10,6 +10,12 @@ stop_stage=100 # This is to avoid OOM during feature extraction. num_splits=1000 +# In case you want to use all validated data +use_validated=false + +# In case you are willing to take the risk and use invalidated data +use_invalidated=false + # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded # by this script automatically. @@ -38,6 +44,7 @@ num_splits=1000 dl_dir=$PWD/download release=cv-corpus-12.0-2022-12-07 lang=fr +perturb_speed=false . shared/parse_options.sh || exit 1 @@ -100,8 +107,40 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then mkdir -p data/${lang}/manifests if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests + + if [ $use_validated = true ] && [ ! -f data/${lang}/manifests/.cv-${lang}.validated.done ]; then + log "Also prepare validated data" + lhotse prepare commonvoice \ + --split validated \ + --language $lang \ + -j $nj $dl_dir/$release data/${lang}/manifests + touch data/${lang}/manifests/.cv-${lang}.validated.done + fi + + if [ $use_invalidated = true ] && [ ! -f data/${lang}/manifests/.cv-${lang}.invalidated.done ]; then + log "Also prepare invalidated data" + lhotse prepare commonvoice \ + --split invalidated \ + --language $lang \ + -j $nj $dl_dir/$release data/${lang}/manifests + touch data/${lang}/manifests/.cv-${lang}.invalidated.done + fi + touch data/${lang}/manifests/.cv-${lang}.done fi + + # Note: in Linux, you can install jq with the following command: + # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + # 2. chmod +x ./jq + # 3. cp jq /usr/bin + if [ $use_validated = true ]; then + log "Getting cut ids from dev/test sets for later use" + gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_test.jsonl.gz \ + | jq '.id' | sed 's/"//g' > data/${lang}/manifests/cv-${lang}_test_ids + + gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_dev.jsonl.gz \ + | jq '.id' | sed 's/"//g' > data/${lang}/manifests/cv-${lang}_dev_ids + fi fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then @@ -121,6 +160,18 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then ./local/preprocess_commonvoice.py --language $lang touch data/${lang}/fbank/.preprocess_complete fi + + if [ $use_validated = true ] && [ ! -f data/${lang}/fbank/.validated.preprocess_complete ]; then + log "Also preprocess validated data" + ./local/preprocess_commonvoice.py --language $lang --dataset validated + touch data/${lang}/fbank/.validated.preprocess_complete + fi + + if [ $use_invalidated = true ] && [ ! -f data/${lang}/fbank/.invalidated.preprocess_complete ]; then + log "Also preprocess invalidated data" + ./local/preprocess_commonvoice.py --language $lang --dataset invalidated + touch data/${lang}/fbank/.invalidated.preprocess_complete + fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then @@ -139,6 +190,20 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir touch $split_dir/.cv-${lang}_train_split.done fi + + split_dir=data/${lang}/fbank/cv-${lang}_validated_split_${num_splits} + if [ $use_validated = true ] && [ ! 
-f $split_dir/.cv-${lang}_validated.done ]; then + log "Also split validated data" + lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_validated_raw.jsonl.gz $split_dir + touch $split_dir/.cv-${lang}_validated.done + fi + + split_dir=data/${lang}/fbank/cv-${lang}_invalidated_split_${num_splits} + if [ $use_invalidated = true ] && [ ! -f $split_dir/.cv-${lang}_invalidated.done ]; then + log "Also split invalidated data" + lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_invalidated_raw.jsonl.gz $split_dir + touch $split_dir/.cv-${lang}_invalidated.done + fi fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then @@ -149,9 +214,36 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then --batch-duration 200 \ --start 0 \ --num-splits $num_splits \ - --language $lang + --language $lang \ + --perturb-speed $perturb_speed touch data/${lang}/fbank/.cv-${lang}_train.done fi + + if [ $use_validated = true ] && [ ! -f data/${lang}/fbank/.cv-${lang}_validated.done ]; then + log "Also compute features for validated data" + ./local/compute_fbank_commonvoice_splits.py \ + --subset validated \ + --num-workers $nj \ + --batch-duration 200 \ + --start 0 \ + --num-splits $num_splits \ + --language $lang \ + --perturb-speed $perturb_speed + touch data/${lang}/fbank/.cv-${lang}_validated.done + fi + + if [ $use_invalidated = true ] && [ ! -f data/${lang}/fbank/.cv-${lang}_invalidated.done ]; then + log "Also compute features for invalidated data" + ./local/compute_fbank_commonvoice_splits.py \ + --subset invalidated \ + --num-workers $nj \ + --batch-duration 200 \ + --start 0 \ + --num-splits $num_splits \ + --language $lang \ + --perturb-speed $perturb_speed + touch data/${lang}/fbank/.cv-${lang}_invalidated.done + fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then @@ -160,6 +252,20 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz") lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz fi + + if [ $use_validated = true ] && [ -f data/${lang}/fbank/.cv-${lang}_validated.done ]; then + log "Also combine features for validated data" + pieces=$(find data/${lang}/fbank/cv-${lang}_validated_split_${num_splits} -name "cv-${lang}_cuts_validated.*.jsonl.gz") + lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_validated.jsonl.gz + touch data/${lang}/fbank/.cv-${lang}_validated.done + fi + + if [ $use_invalidated = true ] && [ -f data/${lang}/fbank/.cv-${lang}_invalidated.done ]; then + log "Also combine features for invalidated data" + pieces=$(find data/${lang}/fbank/cv-${lang}_invalidated_split_${num_splits} -name "cv-${lang}_cuts_invalidated.*.jsonl.gz") + lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_invalidated.jsonl.gz + touch data/${lang}/fbank/.cv-${lang}_invalidated.done + fi fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then @@ -172,83 +278,134 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/${lang}/lang_bpe_${vocab_size} + if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then + log "Stage 9: Prepare Char based lang" + lang_dir=data/${lang}/lang_char/ mkdir -p $lang_dir if [ ! 
-f $lang_dir/transcript_words.txt ]; then - log "Generate data for BPE training" - file=$( - find "data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz" - ) - gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt + log "Generate data for lang preparation" - # Ensure space only appears once - sed -i 's/\t/ /g' $lang_dir/transcript_words.txt - sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt - fi + # Prepare text. + # Note: in Linux, you can install jq with the following command: + # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + # 2. chmod +x ./jq + # 3. cp jq /usr/bin + if [ $use_validated = true ]; then + gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_validated.jsonl.gz \ + | jq '.text' | sed 's/"//g' >> $lang_dir/text + else + gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_train.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $lang_dir/text + fi + + if [ $use_invalidated = true ]; then + gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_invalidated.jsonl.gz \ + | jq '.text' | sed 's/"//g' >> $lang_dir/text + fi - if [ ! -f $lang_dir/words.txt ]; then - cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_dir/words.txt - (echo '!SIL'; echo ''; echo ''; ) | - cat - $lang_dir/words.txt | sort | uniq | awk ' - BEGIN { - print " 0"; - } - { - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; + if [ $lang == "yue" ] || [ $lang == "zh-HK" ]; then + # Get words.txt and words_no_ids.txt + ./local/word_segment_yue.py \ + --input-file $lang_dir/text \ + --output-dir $lang_dir \ + --lang $lang + + mv $lang_dir/text $lang_dir/_text + cp $lang_dir/transcript_words.txt $lang_dir/text + + if [ ! -f $lang_dir/tokens.txt ]; then + ./local/prepare_char.py --lang-dir $lang_dir + fi + else + log "word_segment_${lang}.py not implemented yet" + exit 1 + fi + fi + else + log "Stage 9: Prepare BPE based lang" + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + file=$( + find "data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz" + ) + # Prepare text. + # Note: in Linux, you can install jq with the following command: + # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + # 2. chmod +x ./jq + # 3. cp jq /usr/bin + gunzip -c ${file} \ + | jq '.text' | sed 's/"//g' > $lang_dir/transcript_words.txt + + # Ensure space only appears once + sed -i 's/\t/ /g' $lang_dir/transcript_words.txt + sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/words.txt ]; then + cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' > $lang_dir/words.txt + (echo '!SIL'; echo ''; echo ''; ) | + cat - $lang_dir/words.txt | sort | uniq | awk ' + BEGIN { + print " 0"; } - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; + { + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + if ($1 == "") { + print " is in the vocabulary!" 
| "cat 1>&2" + exit 1; + } + printf("%s %d\n", $1, NR); } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf(" %d\n", NR+2); - printf(" %d\n", NR+3); - }' > $lang_dir/words || exit 1; - mv $lang_dir/words $lang_dir/words.txt - fi + END { + printf("#0 %d\n", NR+1); + printf(" %d\n", NR+2); + printf(" %d\n", NR+3); + }' > $lang_dir/words || exit 1; + mv $lang_dir/words $lang_dir/words.txt + fi - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done + fi fi if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then @@ -256,49 +413,96 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then # We assume you have install kaldilm, if not, please install # it using: pip install kaldilm - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/${lang}/lang_bpe_${vocab_size} + if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then + lang_dir=data/${lang}/lang_char mkdir -p $lang_dir/lm - #3-gram used in building HLG, 4-gram used for LM rescoring - for ngram in 3 4; do - if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order ${ngram} \ - -text $lang_dir/transcript_words.txt \ - -lm $lang_dir/lm/${ngram}gram.arpa - fi - if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/words.txt" \ - --disambig-symbol='#0' \ - --max-order=${ngram} \ - $lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt - fi + for ngram in 3 ; do + if [ ! -f $lang_dir/lm/${ngram}-gram.unpruned.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_words.txt \ + -lm $lang_dir/lm/${ngram}gram.unpruned.arpa + fi + + if [ ! 
-f $lang_dir/lm/G_${ngram}_gram_char.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/lm/${ngram}gram.unpruned.arpa \ + > $lang_dir/lm/G_${ngram}_gram_char.fst.txt + fi + + if [ ! -f $lang_dir/lm/HLG.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_dir \ + --ngram-G $lang_dir/lm/G_${ngram}_gram_char.fst.txt + fi + done + else + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + mkdir -p $lang_dir/lm + #3-gram used in building HLG, 4-gram used for LM rescoring + for ngram in 3 4; do + if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_words.txt \ + -lm $lang_dir/lm/${ngram}gram.arpa + fi + + if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt + fi + done done - done + fi fi if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then log "Stage 11: Compile HLG" - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/${lang}/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir + if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then + lang_dir=data/${lang}/lang_char + for ngram in 3 ; do + if [ ! -f $lang_dir/lm/HLG_${ngram}.fst ]; then + ./local/compile_hlg.py --lang-dir $lang_dir --lm G_${ngram}_gram_char + fi + done + else + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir - done + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done + fi fi # Compile LG for RNN-T fast_beam_search decoding if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then log "Stage 12: Compile LG" - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/${lang}/lang_bpe_${vocab_size} - ./local/compile_lg.py --lang-dir $lang_dir - done + if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then + lang_dir=data/${lang}/lang_char + for ngram in 3 ; do + if [ ! 
-f $lang_dir/lm/LG_${ngram}.fst ]; then + ./local/compile_lg.py --lang-dir $lang_dir --lm G_${ngram}_gram_char + fi + done + else + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done + fi fi diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py index 41009831c..a80cfe85e 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -409,6 +409,22 @@ class CommonVoiceAsrDataModule: self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz" ) + @lru_cache() + def validated_cuts(self) -> CutSet: + logging.info("About to get validated cuts (with dev/test removed)") + return load_manifest_lazy( + self.args.cv_manifest_dir + / f"cv-{self.args.language}_cuts_validated.jsonl.gz" + ) + + @lru_cache() + def invalidated_cuts(self) -> CutSet: + logging.info("About to get invalidated cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir + / f"cv-{self.args.language}_cuts_invalidated.jsonl.gz" + ) + @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 4957c0c31..5e98084ec 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -249,7 +250,29 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." + "--use-validated-set", + type=str2bool, + default=False, + help="""Use the validated set for training. + This is useful when you want to use more data for training, + but not recommended for research purposes. + """, + ) + + parser.add_argument( + "--use-invalidated-set", + type=str2bool, + default=False, + help="""Use the invalidated set for training. + In case you want to take the risk and utilize more data for training. 
+ """, + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.05, + help="The base learning rate.", ) parser.add_argument( @@ -1027,7 +1050,13 @@ def run(rank, world_size, args): commonvoice = CommonVoiceAsrDataModule(args) - train_cuts = commonvoice.train_cuts() + if args.use_validated_set: + train_cuts = commonvoice.validated_cuts() + else: + train_cuts = commonvoice.train_cuts() + + if args.use_invalidated_set: + train_cuts += commonvoice.invalidated_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..c274de28a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/asr_datamodule.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py deleted file mode 100644 index 91220bd11..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class CommonVoiceAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. 
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--language", - type=str, - default="fr", - help="""Language of Common Voice""", - ) - group.add_argument( - "--cv-manifest-dir", - type=Path, - default=Path("data/fr/fbank"), - help="Path to directory with CommonVoice train/dev/test cuts.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. 
", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)() - ), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz" - ) - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz" - ) diff --git 
a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py index 30f7c1e77..7ae4f1894 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -112,6 +113,7 @@ import k2 import sentencepiece as spm import torch import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -122,7 +124,6 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from commonvoice_fr import CommonVoiceAsrDataModule from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index a3f387636..aefe88f3f 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, # Mingshuang Luo,) -# Zengwei Yao) +# Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -55,7 +56,7 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from commonvoice_fr import CommonVoiceAsrDataModule +from asr_datamodule import CommonVoiceAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py index 81c69e5e0..976004eca 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2024 Xiaomi Corp. 
(authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -58,7 +59,7 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from commonvoice_fr import CommonVoiceAsrDataModule +from asr_datamodule import CommonVoiceAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 018736d26..bb1c093c8 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -37,7 +39,7 @@ import numpy as np import sentencepiece as spm import torch import torch.nn as nn -from commonvoice_fr import CommonVoiceAsrDataModule +from asr_datamodule import CommonVoiceAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py index 728104580..67e1a8133 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -55,7 +56,7 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from commonvoice_fr import CommonVoiceAsrDataModule +from asr_datamodule import CommonVoiceAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -265,7 +266,29 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." + "--use-validated-set", + type=str2bool, + default=False, + help="""Use the validated set for training. + This is useful when you want to use more data for training, + but not recommended for research purposes. + """, + ) + + parser.add_argument( + "--use-invalidated-set", + type=str2bool, + default=False, + help="""Use the invalidated set for training. + In case you want to take the risk and utilize more data for training. 
+ """, + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.05, + help="The base learning rate.", ) parser.add_argument( @@ -1044,7 +1067,13 @@ def run(rank, world_size, args): commonvoice = CommonVoiceAsrDataModule(args) - train_cuts = commonvoice.train_cuts() + if not args.use_validated_set: + train_cuts = commonvoice.train_cuts() + else: + train_cuts = commonvoice.validated_cuts() + + if args.use_invalidated_set: + train_cuts += commonvoice.invalidated_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/commonvoice/ASR/zipformer/asr_datamodule.py b/egs/commonvoice/ASR/zipformer/asr_datamodule.py new file mode 120000 index 000000000..c274de28a --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/asr_datamodule.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/beam_search.py b/egs/commonvoice/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/decode.py b/egs/commonvoice/ASR/zipformer/decode.py new file mode 100755 index 000000000..7fd6d0ccd --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/decode.py @@ -0,0 +1,1052 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
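# Illustrative sketch (not part of the patch; names below are placeholders). The
# --use-validated-set / --use-invalidated-set handling added to both train.py scripts above
# reduces to choosing a base CutSet and optionally concatenating another one; the manifest
# directory and language are hypothetical and follow the cv-{language}_cuts_*.jsonl.gz
# naming used by the data module.
from lhotse import CutSet, load_manifest_lazy


def select_train_cuts(
    manifest_dir: str, language: str, use_validated: bool, use_invalidated: bool
) -> CutSet:
    split = "validated" if use_validated else "train"
    cuts = load_manifest_lazy(f"{manifest_dir}/cv-{language}_cuts_{split}.jsonl.gz")
    if use_invalidated:
        # CutSet supports "+", which is all `train_cuts += ...` does in the hunks above.
        cuts = cuts + load_manifest_lazy(
            f"{manifest_dir}/cv-{language}_cuts_invalidated.jsonl.gz"
        )
    return cuts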
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
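# Illustrative note (not part of the patch): the block below right-pads every utterance
# with pad_len frames of LOG_EPS (log(1e-10), defined above) and extends feature_lens to
# match, presumably so the causal/streaming encoder still emits output for the final real
# frames, at the cost noted above of occasional insertions at the end of the utterance.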
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." 
+ assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} 
(excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
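# Illustrative note (not part of the patch): with --return-cuts enabled, the dataloader
# attaches the originating Cut objects under batch["supervisions"]["cut"]; decode_dataset()
# above reads cut.id from there to label each hypothesis in the recogs-* and errs-* files.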
+ args.return_cuts = True + commonvoice = CommonVoiceAsrDataModule(args) + + test_cuts = commonvoice.test_cuts() + dev_cuts = commonvoice.dev_cuts() + + test_dl = commonvoice.test_dataloaders(test_cuts) + dev_dl = commonvoice.valid_dataloaders(dev_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/zipformer/decode_char.py b/egs/commonvoice/ASR/zipformer/decode_char.py new file mode 100755 index 000000000..1f8c9c7c6 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/decode_char.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao +# Mingshuang Luo, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
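# Illustrative sketch (not part of the patch): the --use-averaged-model branch in
# decode.py's main() above only has to work out two checkpoint filenames; exp_dir here is
# a placeholder.
def averaged_model_range(exp_dir: str, epoch: int, avg: int) -> tuple:
    """Endpoints handed to average_checkpoints_with_averaged_model()."""
    start = epoch - avg
    assert start >= 1, start
    return f"{exp_dir}/epoch-{start}.pt", f"{exp_dir}/epoch-{epoch}.pt"


# averaged_model_range("zipformer/exp", 30, 15) returns
# ("zipformer/exp/epoch-15.pt", "zipformer/exp/epoch-30.pt"); as the log message above
# puts it, the averaged model covers the epoch range from 15 (excluded) to 30.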
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/zh-HK/lang_char \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/zh-HK/lang_char \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/zh-HK/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/zh-HK/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/zh-HK/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/zh-HK/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. 
The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + sentence = "".join([lexicon.word_table[i] for i in hyp]) + hyps.append(list(sentence)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, 
+ ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
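# Illustrative note (not part of the patch): the dict returned by
# decode_one_batch() maps one settings key to the per-utterance hypotheses of
# a batch. Assuming greedy search and a blank penalty of 0.0, it could look
# roughly like
#     {"greedy_search_blank_penalty_0.0": [["你", "好"], ["早", "安"]]}
# where each inner list holds the decoded tokens of one utterance (the token
# contents here are made up for illustration).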
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
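# Illustrative example (assumed argument values, not from this patch): with
# --epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256,
# greedy search, the default context size, one symbol per frame and a blank
# penalty of 0.0, the suffix assembled below would resemble
#     epoch-28-avg-15-chunk-32-left-context-256-context-2-max-sym-per-frame-1-blank-penalty-0.0-use-averaged-model
# so every log and result file name records the full decoding configuration.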
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + commonvoice = CommonVoiceAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = commonvoice.dev_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = commonvoice.valid_dataloaders(dev_cuts) + + test_cuts = commonvoice.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = commonvoice.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/zipformer/decode_stream.py b/egs/commonvoice/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/decoder.py b/egs/commonvoice/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/encoder_interface.py b/egs/commonvoice/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py b/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 000000000..f9d756352 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 120000 index 000000000..652346001 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 
+1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py b/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export-onnx.py b/egs/commonvoice/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export.py b/egs/commonvoice/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/joiner.py b/egs/commonvoice/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/model.py b/egs/commonvoice/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/onnx_check.py b/egs/commonvoice/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/onnx_pretrained.py b/egs/commonvoice/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/optim.py b/egs/commonvoice/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/scaling.py b/egs/commonvoice/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/scaling_converter.py b/egs/commonvoice/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/streaming_beam_search.py b/egs/commonvoice/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git 
a/egs/commonvoice/ASR/zipformer/streaming_decode.py b/egs/commonvoice/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..1d0230c76 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/streaming_decode.py @@ -0,0 +1,859 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import CommonVoiceAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. 
+ """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
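# Illustrative note (assumption, not part of the recipe): if audio were fed
# from raw 16-bit PCM instead of cut.load_audio(), it would have to be scaled
# to the normalized float range first, e.g.
#     samples = torch.from_numpy(pcm_int16.astype(np.float32) / 32768.0)
# (pcm_int16 is a hypothetical array); otherwise the
# `np.abs(audio).max() <= 10` sanity check below would fail.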
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    CommonVoiceAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    assert params.causal, params.causal
+    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+    assert (
+        "," not in params.left_context_frames
+    ), "left_context_frames should be one value in decoding."
+    params.suffix += f"-chunk-{params.chunk_size}"
+    params.suffix += f"-left-context-{params.left_context_frames}"
+
+    # for fast_beam_search
+    if params.decoding_method == "fast_beam_search":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames =
find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + commonvoice = CommonVoiceAsrDataModule(args) + + test_cuts = commonvoice.test_cuts() + dev_cuts = commonvoice.dev_cuts() + + test_sets = ["test", "dev"] + test_cuts = [test_cuts, dev_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/zipformer/streaming_decode_char.py b/egs/commonvoice/ASR/zipformer/streaming_decode_char.py new file mode 100755 index 000000000..249cba9f5 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/streaming_decode_char.py @@ -0,0 +1,861 @@ +#!/usr/bin/env python3 +# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
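# Illustrative comparison (annotation, based only on the two scripts in this
# diff): this file mirrors streaming_decode.py above but decodes with a
# character lexicon instead of a SentencePiece BPE model, so hypotheses are
# mapped as
#     [lexicon.token_table[idx] for idx in decode_streams[i].decoding_result()]
# rather than
#     sp.decode(decode_streams[i].decoding_result()).split()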
+ +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +from asr_datamodule import CommonVoiceAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/zh-HK/lang_char", + help="Path to the lang dir(containing lexicon, tokens, etc.)", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. 
+ """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
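+        # For example (illustrative values): decoding the "test" set with
+        # greedy search, --epoch 30 --avg 9, --chunk-size 16 and
+        # --left-context-frames 64 writes
+        #   streaming/greedy_search/errs-test-greedy_search-epoch-30-avg-9-chunk-16-left-context-64.txt
+        # under the experiment dir, next to the recogs-*.txt file above.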
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + 
raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + commonvoice = CommonVoiceAsrDataModule(args) + + test_cuts = commonvoice.test_cuts() + dev_cuts = commonvoice.dev_cuts() + + test_sets = ["test", "dev"] + test_cuts = [test_cuts, dev_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/zipformer/subsampling.py b/egs/commonvoice/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py new file mode 100755 index 000000000..5cda9bfd4 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/train.py @@ -0,0 +1,1411 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--use-validated-set", + type=str2bool, + default=False, + help="""Use the validated set for training. + This is useful when you want to use more data for training, + but not recommended for research purposes. + """, + ) + + parser.add_argument( + "--use-invalidated-set", + type=str2bool, + default=False, + help="""Use the invalidated set for training. + In case you want to take the risk and utilize more data for training. + """, + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.045, + help="The base learning rate.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
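+    # For example (rough numbers): a 10-second utterance has T = 1000 feature
+    # frames (100 frames per second), so encoder_embed outputs
+    # (1000 - 7) // 2 = 496 frames at ~50 Hz; the encoder then downsamples by
+    # another factor of 2 at its output, giving ~248 frames at ~25 Hz.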
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
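+
+    For example, running with --start-epoch 11 loads
+    `params.exp_dir/epoch-10.pt`, while --start-batch 8000 loads
+    `params.exp_dir/checkpoint-8000.pt` instead; in both cases the best-loss
+    statistics and `batch_idx_train` stored in the checkpoint are copied
+    back into `params`.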
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
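+
+    For example, with the default --simple-loss-scale 0.5 and warm_step = 2000:
+    at batch_idx_train = 1000 the transducer part of the loss below is
+    0.75 * simple_loss + 0.55 * pruned_loss, and from batch_idx_train = 2000
+    onwards it settles at 0.5 * simple_loss + 1.0 * pruned_loss
+    (plus ctc_loss_scale * ctc_loss if --use-ctc is given).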
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
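+            # For example, a scale that has decayed to 4.0 is doubled to 8.0
+            # here.  If the scale ever falls below 0.01 a warning is logged
+            # (and a bad-model checkpoint is saved once); below 1e-05 training
+            # is aborted, since such a tiny scale usually means the loss has
+            # become inf/nan.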
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + if not args.use_validated_set: + train_cuts = commonvoice.train_cuts() + else: + train_cuts = commonvoice.validated_cuts() + + if args.use_invalidated_set: + train_cuts += commonvoice.invalidated_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + dev_cuts = commonvoice.dev_cuts() + dev_dl = commonvoice.valid_dataloaders(dev_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=dev_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
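+
+    The saved file can be reloaded later for offline debugging, e.g.
+    `batch = torch.load(".../batch-<uuid>.pt")`, and fed to
+    :func:`compute_loss` again; the exact filename is printed in the log.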
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py new file mode 100755 index 000000000..a780bbbbc --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/train_char.py @@ -0,0 +1,1051 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
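+
+# Note: this file is the character-lexicon counterpart of ./zipformer/train.py.
+# It reuses the model-building and checkpointing helpers imported from train.py
+# below, and replaces the BPE tokenizer with a CharCtcTrainingGraphCompiler
+# built from --lang-dir (e.g. data/zh-HK/lang_char).  In the usage examples
+# below, substitute train_char.py for train.py and adjust --lang-dir if needed.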
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from typing import Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from train import ( + add_model_arguments, + get_adjusted_batch_count, + get_model, + load_checkpoint_if_available, + save_checkpoint, + set_batch_count, +) + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/zh-HK/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--use-validated-set", + type=str2bool, + default=False, + help="""Use the validated set for training. + This is useful when you want to use more data for training, + but not recommended for research purposes. + """, + ) + + parser.add_argument( + "--use-invalidated-set", + type=str2bool, + default=False, + help="""Use the invalidated set for training. + In case you want to take the risk and utilize more data for training. + """, + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.045, + help="The base learning rate.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. 
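+    For example, with the default --save-every-n 4000 the files
+    checkpoint-4000.pt, checkpoint-8000.pt, ... are written to exp-dir, and
+    --keep-last-k (below) controls how many of them are kept.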
+ """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. 
See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
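+            # The block below implements that manual adjustment: read the
+            # current scale and, if it has fallen below 8.0 (or below 32.0,
+            # checked every 400 batches), double it. A scale below 0.01
+            # triggers a warning and a saved "bad model" checkpoint, and a
+            # scale below 1e-5 aborts training, since that usually means the
+            # losses have become inf/nan.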
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = supervisions["text"] + y = graph_compiler.texts_to_ids(texts) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + 
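+        # As described for the --inf-check option, this attaches hooks that
+        # check module outputs and gradients for infinities during training.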
register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + if not args.use_validated_set: + train_cuts = commonvoice.train_cuts() + else: + train_cuts = commonvoice.validated_cuts() + + if args.use_invalidated_set: + train_cuts += commonvoice.invalidated_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + dev_cuts = commonvoice.dev_cuts() + dev_dl = commonvoice.valid_dataloaders(dev_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=dev_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = 
args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +# torch.set_num_threads(1) +# torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/zipformer/zipformer.py b/egs/commonvoice/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/commonvoice/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 1732dafe2412cccf85b1c2972d415bf50ca9e3b5 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 9 Apr 2024 12:06:14 +0800 Subject: [PATCH 149/216] Add zipformer recipe for audio tagging (#1421) --- .pre-commit-config.yaml | 2 +- egs/audioset/AT/README.md | 12 + egs/audioset/AT/RESULTS.md | 44 + egs/audioset/AT/local/compute_fbank_musan.py | 1 + .../AT/local/generate_audioset_manifest.py | 177 +++ egs/audioset/AT/prepare.sh | 104 ++ egs/audioset/AT/shared | 1 + egs/audioset/AT/zipformer/at_datamodule.py | 420 ++++++ .../AT/zipformer/encoder_interface.py | 1 + egs/audioset/AT/zipformer/evaluate.py | 344 +++++ egs/audioset/AT/zipformer/export-onnx.py | 411 ++++++ egs/audioset/AT/zipformer/export.py | 339 +++++ egs/audioset/AT/zipformer/jit_pretrained.py | 181 +++ egs/audioset/AT/zipformer/model.py | 157 +++ egs/audioset/AT/zipformer/onnx_pretrained.py | 250 ++++ egs/audioset/AT/zipformer/optim.py | 1 + egs/audioset/AT/zipformer/pretrained.py | 202 +++ egs/audioset/AT/zipformer/scaling.py | 1 + .../AT/zipformer/scaling_converter.py | 1 + egs/audioset/AT/zipformer/subsampling.py | 1 + egs/audioset/AT/zipformer/train.py | 1186 +++++++++++++++++ egs/audioset/AT/zipformer/zipformer.py | 1 + 22 files changed, 3836 insertions(+), 1 deletion(-) create mode 100644 egs/audioset/AT/README.md create mode 100644 egs/audioset/AT/RESULTS.md create mode 120000 egs/audioset/AT/local/compute_fbank_musan.py create mode 100644 egs/audioset/AT/local/generate_audioset_manifest.py create mode 100755 egs/audioset/AT/prepare.sh create mode 120000 egs/audioset/AT/shared create mode 100644 egs/audioset/AT/zipformer/at_datamodule.py create mode 120000 egs/audioset/AT/zipformer/encoder_interface.py create mode 100644 egs/audioset/AT/zipformer/evaluate.py create mode 100755 egs/audioset/AT/zipformer/export-onnx.py create mode 100755 egs/audioset/AT/zipformer/export.py create mode 100755 egs/audioset/AT/zipformer/jit_pretrained.py create mode 100644 egs/audioset/AT/zipformer/model.py create mode 100755 egs/audioset/AT/zipformer/onnx_pretrained.py create mode 120000 egs/audioset/AT/zipformer/optim.py create mode 100755 egs/audioset/AT/zipformer/pretrained.py create mode 120000 egs/audioset/AT/zipformer/scaling.py create mode 120000 egs/audioset/AT/zipformer/scaling_converter.py create mode 120000 egs/audioset/AT/zipformer/subsampling.py create mode 100644 egs/audioset/AT/zipformer/train.py create mode 120000 egs/audioset/AT/zipformer/zipformer.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5cb213327..70068f9cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: # E121,E123,E126,E226,E24,E704,W503,W504 - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: ["--profile=black"] diff --git a/egs/audioset/AT/README.md b/egs/audioset/AT/README.md new file mode 100644 index 000000000..2506d41e5 --- 
/dev/null +++ b/egs/audioset/AT/README.md @@ -0,0 +1,12 @@ +# Introduction + +This is an audio tagging recipe for [Audioset](https://research.google.com/audioset/#/). It aims at predicting the sound events of an audio clip. + +[./RESULTS.md](./RESULTS.md) contains the latest results. + + +# Zipformer + +| Encoder | Feature type | +| --------| -------------| +| Zipformer | Frame level fbank| diff --git a/egs/audioset/AT/RESULTS.md b/egs/audioset/AT/RESULTS.md new file mode 100644 index 000000000..0c75dfe4e --- /dev/null +++ b/egs/audioset/AT/RESULTS.md @@ -0,0 +1,44 @@ +## Results + +### zipformer +See for more details + +[zipformer](./zipformer) + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +The model achieves the following mean averaged precision on AudioSet: + +| Model | mAP | +| ------ | ------- | +| Zipformer-AT | 45.1 | + +The training command is: + +```bash +export CUDA_VISIBLE_DEVICES="4,5,6,7" +subset=full + +python zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --exp-dir zipformer/exp_at_as_${subset} \ + --start-epoch 1 \ + --use-fp16 1 \ + --num-events 527 \ + --audioset-subset $subset \ + --max-duration 1000 \ + --enable-musan True \ + --master-port 13455 +``` + +The evaluation command is: + +```bash +python zipformer/evaluate.py \ + --epoch 32 \ + --avg 8 \ + --exp-dir zipformer/exp_at_as_full \ + --max-duration 500 +``` diff --git a/egs/audioset/AT/local/compute_fbank_musan.py b/egs/audioset/AT/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/audioset/AT/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/audioset/AT/local/generate_audioset_manifest.py b/egs/audioset/AT/local/generate_audioset_manifest.py new file mode 100644 index 000000000..1c5b3457c --- /dev/null +++ b/egs/audioset/AT/local/generate_audioset_manifest.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file generates the manifest and computes the fbank features for AudioSet +dataset. The generated manifests and features are stored in data/fbank. 
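+
+A typical invocation, mirroring how prepare.sh calls this script (the paths
+below are illustrative; --dataset-dir and --feat-output-dir default to
+"downloads/audioset" and "data/fbank"):
+
+    python local/generate_audioset_manifest.py \
+        --dataset-dir download/audioset \
+        --split balanced \
+        --feat-output-dir data/fbank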
+""" + +import argparse +import csv +import glob +import logging +import os +from typing import Dict + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.audio import Recording +from lhotse.cut import MonoCut +from lhotse.supervision import SupervisionSegment + +from icefall.utils import get_executor + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_ID_mapping(csv_file): + # get a mapping between class ID and class name + mapping = {} + with open(csv_file, "r") as fin: + reader = csv.reader(fin, delimiter=",") + for i, row in enumerate(reader): + if i == 0: + continue + mapping[row[1]] = row[0] + return mapping + + +def parse_csv(csv_file: str, id_mapping: Dict): + # The content of the csv file shoud be something like this + # ------------------------------------------------------ + # filename label + # dataset/AudioSet/balanced/xxxx.wav 0;451 + # dataset/AudioSet/balanced/xxxy.wav 375 + # ------------------------------------------------------ + + def name2id(names): + ids = [id_mapping[name] for name in names.split(",")] + return ";".join(ids) + + mapping = {} + with open(csv_file, "r") as fin: + reader = csv.reader(fin, delimiter=" ") + for i, row in enumerate(reader): + if i <= 2: + continue + key = row[0].replace(",", "") + mapping[key] = name2id(row[-1]) + return mapping + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--dataset-dir", type=str, default="downloads/audioset") + + parser.add_argument( + "--split", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "eval"], + ) + + parser.add_argument( + "--feat-output-dir", + type=str, + default="data/fbank", + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + dataset_dir = args.dataset_dir + split = args.split + feat_output_dir = args.feat_output_dir + + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + if split in ["balanced", "unbalanced"]: + csv_file = f"{dataset_dir}/{split}_train_segments.csv" + elif split == "eval": + csv_file = f"{dataset_dir}/eval_segments.csv" + else: + raise ValueError() + + class_indices_csv = f"{dataset_dir}/class_labels_indices.csv" + id_mapping = get_ID_mapping(class_indices_csv) + labels = parse_csv(csv_file, id_mapping) + + audio_files = glob.glob(f"{dataset_dir}/{split}/*.wav") + + new_cuts = [] + for i, audio in enumerate(audio_files): + cut_id = audio.split("/")[-1].split("_")[0] + recording = Recording.from_file(audio, cut_id) + cut = MonoCut( + id=cut_id, + start=0.0, + duration=recording.duration, + channel=0, + recording=recording, + ) + supervision = SupervisionSegment( + id=cut_id, + recording_id=cut.recording.id, + start=0.0, + channel=0, + duration=cut.duration, + ) + try: + supervision.audio_event = labels[cut_id] + except KeyError: + logging.info(f"No labels found for {cut_id}.") + continue + cut.supervisions = [supervision] + new_cuts.append(cut) + + if i % 100 == 0 and i: + logging.info(f"Processed {i} cuts until now.") + + cuts = CutSet.from_cuts(new_cuts) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + logging.info(f"Computing fbank features for {split}") + with get_executor() as ex: + cuts = cuts.compute_and_store_features( + extractor=extractor, + storage_path=f"{feat_output_dir}/{split}_feats", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + manifest_output_dir = feat_output_dir + 
"/" + f"cuts_audioset_{split}.jsonl.gz" + + logging.info(f"Storing the manifest to {manifest_output_dir}") + cuts.to_jsonl(manifest_output_dir) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/audioset/AT/prepare.sh b/egs/audioset/AT/prepare.sh new file mode 100755 index 000000000..f7f73a008 --- /dev/null +++ b/egs/audioset/AT/prepare.sh @@ -0,0 +1,104 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +# run step 0 to step 5 by default +stage=-1 +stop_stage=4 + +dl_dir=$PWD/download + +# we assume that you have your downloaded the AudioSet and placed +# it under $dl_dir/audioset, the folder structure should look like +# this: +# - $dl_dir/audioset +# - balanced +# - eval +# - unbalanced +# If you haven't downloaded the AudioSet, please refer to +# https://github.com/RicherMans/SAT/blob/main/datasets/audioset/1_download_audioset.sh. + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "Running prepare.sh" + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage 0: Download the necessary csv files" + if [ ! -e $dl_dir/audioset/.csv.done]; then + wget --continue "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv" -O "${dl_dir}/audioset/class_labels_indices.csv" + wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv -O "${dl_dir}/audioset/balanced_train_segments.csv" + wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/eval_segments.csv -O "${dl_dir}/audioset/eval_segments.csv" + touch $dl_dir/audioset/.csv.done + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set" + fbank_dir=data/fbank + if [! -e $fbank_dir/.balanced.done]; then + python local/generate_audioset_manifest.py \ + --dataset-dir $dl_dir/audioset \ + --split balanced \ + --feat-output-dir $fbank_dir + touch $fbank_dir/.balanced.done + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Construct the audioset manifest and compute the fbank features for unbalanced set" + fbank_dir=data/fbank + if [! -e $fbank_dir/.unbalanced.done]; then + python local/generate_audioset_manifest.py \ + --dataset-dir $dl_dir/audioset \ + --split unbalanced \ + --feat-output-dir $fbank_dir + touch $fbank_dir/.unbalanced.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Construct the audioset manifest and compute the fbank features for eval set" + fbank_dir=data/fbank + if [! 
+    python local/generate_audioset_manifest.py \
+      --dataset-dir $dl_dir/audioset \
+      --split eval \
+      --feat-output-dir $fbank_dir
+    touch $fbank_dir/.eval.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+  mkdir -p data/manifests
+  if [ ! -e data/manifests/.musan.done ]; then
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.musan.done ]; then
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
diff --git a/egs/audioset/AT/shared b/egs/audioset/AT/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/audioset/AT/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py
new file mode 100644
index 000000000..66497c1ca
--- /dev/null
+++ b/egs/audioset/AT/zipformer/at_datamodule.py
@@ -0,0 +1,420 @@
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+    AudioTaggingDataset,
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    PrecomputedFeatures,
+    SimpleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
+    AudioSamples,
+    OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class AudioSetATDatamodule:
+    """
+    DataModule for k2 audio tagging (AT) experiments.
+
+    It contains all the common data pipeline modules used in AT
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in AT tasks.
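+
+    A rough usage sketch (argument wiring is assumed to be done by the
+    training script, e.g. zipformer/train.py):
+
+        parser = argparse.ArgumentParser()
+        AudioSetATDatamodule.add_arguments(parser)
+        args = parser.parse_args()
+
+        audioset = AudioSetATDatamodule(args)
+        train_cuts = audioset.audioset_train_cuts()
+        train_dl = audioset.train_dataloaders(train_cuts)
+        valid_dl = audioset.valid_dataloaders(audioset.audioset_eval_cuts())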
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="AT data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--audioset-subset", + type=str, + default="balanced", + choices=["balanced", "full"], + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with audioset train/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ): + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = AudioTaggingDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = AudioTaggingDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = AudioTaggingDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = AudioTaggingDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = AudioTaggingDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def audioset_train_cuts(self) -> CutSet: + logging.info("About to get the audioset training cuts.") + balanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" + ) + if self.args.audioset_subset == "full": + unbalanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" + ) + cuts = CutSet.mux( + balanced_cuts, + unbalanced_cuts, + weights=[20000, 2000000], + stop_early=True, + ) + else: + cuts = balanced_cuts + return cuts + + @lru_cache() + def audioset_eval_cuts(self) -> CutSet: + logging.info("About to get audioset eval cuts") + return load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz" + ) diff --git 
a/egs/audioset/AT/zipformer/encoder_interface.py b/egs/audioset/AT/zipformer/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/audioset/AT/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/evaluate.py b/egs/audioset/AT/zipformer/evaluate.py new file mode 100644 index 000000000..b52a284d0 --- /dev/null +++ b/egs/audioset/AT/zipformer/evaluate.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0" + +./zipformer/evaluate.py \ + --epoch 50 \ + --avg 10 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + + +""" + +import argparse +import csv +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +import torch.nn.functional as F +from at_datamodule import AudioSetATDatamodule +from lhotse import load_manifest + +try: + from sklearn.metrics import average_precision_score +except Exception as ex: + raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn") +from train import add_model_arguments, get_model, get_params, str2multihot + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
+ "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + add_model_arguments(parser) + + return parser + + +def inference_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +): + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3, feature.shape + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + audio_event = supervisions["audio_event"] + + label, _ = str2multihot(audio_event) + label = label.detach().cpu() + + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + audio_logits = model.forward_audio_tagging(encoder_out, encoder_out_lens) + # convert to probabilities between 0-1 + audio_logits = audio_logits.sigmoid().detach().cpu() + + return audio_logits, label + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict: + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + all_logits = [] + all_labels = [] + + for batch_idx, batch in enumerate(dl): + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + num_cuts += len(cut_ids) + + audio_logits, labels = inference_one_batch( + params=params, + model=model, + batch=batch, + ) + + all_logits.append(audio_logits) + all_labels.append(labels) + + if batch_idx % 20 == 1: + logging.info(f"Processed {num_cuts} cuts already.") + logging.info("Finish collecting audio logits") + + return all_logits, all_labels + + +@torch.no_grad() +def main(): + parser = get_parser() + AudioSetATDatamodule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "inference_audio_tagging" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Evaluation started") + + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info("About to create model") + + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), 
strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + args.return_cuts = True + audioset = AudioSetATDatamodule(args) + + audioset_cuts = audioset.audioset_eval_cuts() + + audioset_dl = audioset.valid_dataloaders(audioset_cuts) + + test_sets = ["audioset_eval"] + + logits, labels = decode_dataset( + dl=audioset_dl, + params=params, + model=model, + ) + + logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy() + labels = torch.cat(labels, dim=0).long().detach().numpy() + + # compute the metric + mAP = average_precision_score( + y_true=labels, + y_score=logits, + ) + + logging.info(f"mAP for audioset eval is: {mAP}") + + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py new file mode 100755 index 000000000..af83c0e9c --- /dev/null +++ b/egs/audioset/AT/zipformer/export-onnx.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/ +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./zipformer/export-onnx.py \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict + +import k2 +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
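+        For example, the dictionary written by this script looks like
+        ``{"model_type": "zipformer2_at", "version": "1", ...}``; see
+        ``meta_data`` in ``export_audio_tagging_model_onnx`` below.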
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxAudioTagger(nn.Module): + """A wrapper for Zipformer audio tagger""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, classifier: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.classifier = classifier + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> torch.Tensor: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tensor containing: + - logits, A 2-D tensor of shape (N, num_classes) + + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C) + + logits = self.classifier(encoder_out) # (N, T, num_classes) + # Note that this is slightly different from model.py for better + # support of onnx + logits = logits.mean(dim=1) + return logits + + +def export_audio_tagging_model_onnx( + model: OnnxAudioTagger, + filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + model: + The input encoder model + filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
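+
+    Note: the graph exported here exposes a single output named ``logits``
+    with shape (N, num_classes); see ``output_names`` in the
+    ``torch.onnx.export`` call below.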
+ """ + x = torch.zeros(1, 200, 80, dtype=torch.float32) + x_lens = torch.tensor([200], dtype=torch.int64) + + model = torch.jit.trace(model, (x, x_lens)) + + torch.onnx.export( + model, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["logits"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "logits": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2_at", + "version": "1", + "model_author": "k2-fsa", + "comment": "zipformer2 audio tagger", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + model = OnnxAudioTagger( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + 
classifier=model.classifier, + ) + + model_num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"total parameters: {model_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting audio tagging model") + model_filename = params.exp_dir / f"model-{suffix}.onnx" + export_audio_tagging_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported audio tagging model to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"model-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/audioset/AT/zipformer/export.py b/egs/audioset/AT/zipformer/export.py new file mode 100755 index 000000000..bdcf8b7dd --- /dev/null +++ b/egs/audioset/AT/zipformer/export.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --epoch 30 \ + --avg 9 + + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/evaluate.py \ + --exp-dir ./zipformer/exp \ + --use-averaged-model False \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 + +Check ./pretrained.py for its usage. 
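+
+`pretrained.pt` is simply a dict with a "model" key holding the state_dict,
+so it can also be loaded manually. A minimal sketch (mirroring what
+./pretrained.py does; build `model` with get_model() from ./train.py first):
+
+    checkpoint = torch.load("zipformer/exp/pretrained.pt", map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)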
+ +""" + +import argparse +import logging +from pathlib import Path +from typing import Tuple + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class Classifier(nn.Module): + """A wrapper for audio tagging classifier""" + + def __init__(self, classifier: nn.Module) -> None: + super().__init__() + self.classifier = classifier + + def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor): + """ + Args: + encoder_out: + A 3-D tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. 
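+        Returns:
+          A 2-D tensor of shape (N, num_classes): logits averaged over the
+          non-padding frames of each utterance.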
+ """ + logits = self.classifier(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) # mask the padding frames + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( + logits + ) # normalize the logits + + return logits + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + model.classifier = Classifier(model.classifier) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/audioset/AT/zipformer/jit_pretrained.py b/egs/audioset/AT/zipformer/jit_pretrained.py new file mode 100755 index 000000000..8e3afcb6f --- /dev/null +++ b/egs/audioset/AT/zipformer/jit_pretrained.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# 2024 Xiaoyu Yang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zipformer/jit_pretrained.py \ + --nn-model-filename ./zipformer/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import csv +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--label-dict", + type=str, + help="""class_labels_indices.csv.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
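+
+    Note:
+      Only the first channel of each file is used, and every file is expected
+      to be sampled at `expected_sample_rate` (16 kHz by default).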
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + # get the label dictionary + label_dict = {} + with open(args.label_dict, "r") as f: + reader = csv.reader(f, delimiter=",") + for i, row in enumerate(reader): + if i == 0: + continue + label_dict[int(row[0])] = row[2] + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + features=features, + feature_lengths=feature_lengths, + ) + + logits = model.classifier(encoder_out, encoder_out_lens) + + for filename, logit in zip(args.sound_files, logits): + topk_prob, topk_index = logit.sigmoid().topk(5) + topk_labels = [label_dict[index.item()] for index in topk_index] + logging.info( + f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" + ) + + logging.info("Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/audioset/AT/zipformer/model.py b/egs/audioset/AT/zipformer/model.py new file mode 100644 index 000000000..f189eac62 --- /dev/null +++ b/egs/audioset/AT/zipformer/model.py @@ -0,0 +1,157 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
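+
+# This file defines AudioTaggingModel, a Zipformer-based audio tagging model:
+# fbank features go through encoder_embed and the encoder, and a linear
+# classifier on top produces multi-label logits over the audio event classes
+# (527 for AudioSet by default), trained with a binary cross-entropy loss.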
+
+import logging
+import random
+from typing import List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from encoder_interface import EncoderInterface
+
+from icefall.utils import AttributeDict, make_pad_mask
+
+
+class AudioTaggingModel(nn.Module):
+    def __init__(
+        self,
+        encoder_embed: nn.Module,
+        encoder: EncoderInterface,
+        encoder_dim: int = 384,
+        num_events: int = 527,
+    ):
+        """An audio tagging model
+
+        Args:
+          encoder_embed:
+            It is a Convolutional 2D subsampling module. It converts
+            an input of shape (N, T, idim) to an output of shape
+            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
+          encoder:
+            The Zipformer encoder. It accepts two inputs: `x` of shape
+            (N, T, encoder_dim) and `x_lens` of shape (N,). It returns
+            two tensors: `encoder_out` of shape (N, T, encoder_dim) and
+            `encoder_out_lens` of shape (N,).
+          encoder_dim:
+            Dimension of the encoder.
+          num_events:
+            The number of audio event classes.
+        """
+        super().__init__()
+
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+
+        self.encoder_embed = encoder_embed
+        self.encoder = encoder
+        self.encoder_dim = encoder_dim
+
+        self.classifier = nn.Sequential(
+            nn.Dropout(0.1),
+            nn.Linear(encoder_dim, num_events),
+        )
+
+        # for multi-label (multi-hot) classification
+        self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")
+
+    def forward_encoder(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute encoder outputs.
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+
+        Returns:
+          encoder_out:
+            Encoder output, of shape (N, T, C).
+          encoder_out_lens:
+            Encoder output lengths, of shape (N,).
+        """
+        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
+        x, x_lens = self.encoder_embed(x, x_lens)
+        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
+
+        src_key_padding_mask = make_pad_mask(x_lens)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+
+        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
+
+        return encoder_out, encoder_out_lens
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          target:
+            The ground-truth labels of the audio events; it may be multi-hot.
+        Returns:
+          Return the binary cross-entropy loss.
+        """
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+
+        # Compute encoder outputs
+        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
+
+        # Forward the audio tagging classifier
+        logits = self.forward_audio_tagging(
+            encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
+        )  # (N, num_classes)
+
+        loss = self.criterion(logits, target)
+
+        return loss
+
+    def forward_audio_tagging(self, encoder_out, encoder_out_lens):
+        """
+        Args:
+          encoder_out:
+            A 3-D tensor of shape (N, T, C).
+          encoder_out_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+
+        Returns:
+          A 2-D tensor of shape (N, num_classes).
+ """ + logits = self.classifier(encoder_out) # (N, T, num_classes) + padding_mask = make_pad_mask(encoder_out_lens) + logits[padding_mask] = 0 + logits = logits.sum(dim=1) # mask the padding frames + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( + logits + ) # normalize the logits + + return logits diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py new file mode 100755 index 000000000..c7753715a --- /dev/null +++ b/egs/audioset/AT/zipformer/onnx_pretrained.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# 2022 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +We use the pre-trained model from +https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/ +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/ +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - model-epoch-99-avg-1.onnx + +3. Run this file + +./zipformer/onnx_pretrained.py \ + --model-filename $repo/exp/model-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +""" + +import argparse +import csv +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--label-dict", + type=str, + help="""class_labels_indices.csv.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a Tensor: + - logits, its shape is (N, num_classes) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.model_filename, + ) + + # get the label dictionary + label_dict = {} + with open(args.label_dict, "r") as f: + reader = csv.reader(f, delimiter=",") + for i, row in enumerate(reader): + if i == 0: + continue + label_dict[int(row[0])] = row[2] + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + logits = model(features, feature_lengths) + + for filename, logit in zip(args.sound_files, logits): + topk_prob, topk_index = logit.sigmoid().topk(5) + topk_labels = [label_dict[index.item()] for index in topk_index] + logging.info( + f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" + ) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git 
a/egs/audioset/AT/zipformer/optim.py b/egs/audioset/AT/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/audioset/AT/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/pretrained.py b/egs/audioset/AT/zipformer/pretrained.py new file mode 100755 index 000000000..60e4d0518 --- /dev/null +++ b/egs/audioset/AT/zipformer/pretrained.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + /path/to/foo.wav \ + /path/to/bar.wav + + +You can also use `./zipformer/exp/epoch-xx.pt`. + +Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py +""" + + +import argparse +import csv +import logging +import math +from typing import List + +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--label-dict", + type=str, + help="""class_labels_indices.csv.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + # get the label dictionary + label_dict = {} + with open(params.label_dict, "r") as f: + reader = csv.reader(f, delimiter=",") + for i, row in enumerate(reader): + if i == 0: + continue + label_dict[int(row[0])] = row[2] + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward and predict the audio events + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + logits = model.forward_audio_tagging(encoder_out, encoder_out_lens) + + for filename, logit in zip(args.sound_files, logits): + topk_prob, topk_index = logit.sigmoid().topk(5) + topk_labels = [label_dict[index.item()] for index in topk_index] + logging.info( + f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" + ) + + logging.info("Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/audioset/AT/zipformer/scaling.py b/egs/audioset/AT/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/audioset/AT/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/scaling_converter.py b/egs/audioset/AT/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/audioset/AT/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/subsampling.py b/egs/audioset/AT/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/audioset/AT/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py new file mode 100644 index 000000000..0e234c59f --- /dev/null +++ 
b/egs/audioset/AT/zipformer/train.py @@ -0,0 +1,1186 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --audioset-subset full \ + --max-duration 1000 + + +""" + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from at_datamodule import AudioSetATDatamodule +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AudioTaggingModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
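+    # For example (hypothetical values): with --max-duration 1000, --world-size 4
+    # and the default --ref-duration 600, each real batch advances the adjusted
+    # count by 1000 * 4 / 600, i.e. about 6.7 reference batches.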
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model. Do not recommend to use this for AT", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--num-events", type=int, default=527, help="Number of sound events" + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. 
`model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def _str2modulelist(s: str, add_dot: bool = True): + if add_dot: + return [ss.strip() + "." for ss in s.split(",")] if s is not None else None + else: + return [ss.strip() for ss in s.split(",")] if s is not None else None + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
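+    # A concrete (illustrative) example: a 10 second clip at the usual 100
+    # fbank frames per second has T = 1000, so encoder_embed outputs
+    # (1000 - 7) // 2 = 496 frames, and the final output downsampling by 2
+    # brings it to roughly 248 frames (about 50 per second).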
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_model(params: AttributeDict) -> nn.Module: + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + model = AudioTaggingModel( + encoder_embed=encoder_embed, + encoder=encoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + num_events=params.num_events, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + events = supervisions[ + "audio_event" + ] # the label indices are in CED format (https://github.com/RicherMans/CED) + labels, _ = str2multihot(events, n_classes=params.num_events) + labels = labels.to(device) + + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + with torch.set_grad_enabled(is_training): + loss = model( + x=feature, + x_lens=feature_lens, + target=labels, + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
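+    # Hence info["loss"] holds the loss summed over the batch; it is divided
+    # by info["frames"] when training/validation losses are reported.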
+ info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def str2multihot(events: List[str], n_classes=527, id_mapping=None): + # Convert strings separated by semi-colon to multi-hot class labels + # input: ["0;1", "1;2"] + # output: torch.tensor([[1,1,0], [0,1,1]]) + labels = [list(map(int, event.split(";"))) for event in events] + batch_size = len(labels) + out = torch.zeros(batch_size, n_classes) + + for i, label in enumerate(labels): + if id_mapping is not None: + label = [id_mapping[lb] for lb in label] + out[i, label] = 1 + + return out, labels + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = batch["inputs"].size(0) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
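+            # When fp16 training is enabled, the GradScaler multiplies the loss
+            # by its scale factor so that small half-precision gradients do not
+            # underflow; scaler.step() unscales the gradients before the
+            # optimizer update and scaler.update() adjusts the scale factor.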
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs( + model, + lr=params.base_lr, + include_names=True, + ), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + audioset = AudioSetATDatamodule(args) + train_cuts = audioset.audioset_train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 30.0: + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = audioset.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = audioset.audioset_eval_cuts() + valid_dl = audioset.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch( + batch, + params=params, + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AudioSetATDatamodule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/audioset/AT/zipformer/zipformer.py b/egs/audioset/AT/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/audioset/AT/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From f5d781873334cec84e1248086b1bd8207132283e Mon Sep 17 00:00:00 2001 From: yh646492956 <35254755+yh646492956@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:16:12 +0800 Subject: [PATCH 150/216] fix run.sh script in wenetspeech KWS (#1584) Co-authored-by: Hao You <13182720519@sina.cn> --- egs/wenetspeech/KWS/run.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 8698e9fcc..8455cc5be 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -119,11 +119,11 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then # We recommend to start from an averaged model finetune_ckpt=zipformer/exp/pretrained.pt - ./zipformer/finetune.py \ + python ./zipformer/finetune.py \ --world-size 4 \ --num-epochs 10 \ --start-epoch 1 \ - --exp-dir zipformer/exp_finetune + --exp-dir zipformer/exp_finetune \ --lang-dir ./data/lang_partial_tone \ --pinyin-type partial_with_tone \ --use-fp16 1 \ From fa5d861af08eabbcd37b5326d4cf1e8079d5ca07 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 9 Apr 2024 17:45:00 +0800 Subject: [PATCH 151/216] Add CI test for the AudioSet recipe. 
(#1585) --- .github/scripts/audioset/AT/run.sh | 94 +++++++++++++ .github/scripts/docker/Dockerfile | 2 + .github/workflows/audioset.yml | 137 +++++++++++++++++++ docker/torch1.12.1-cuda11.3.dockerfile | 2 + docker/torch1.13.0-cuda11.6.dockerfile | 2 + docker/torch1.9.0-cuda10.2.dockerfile | 2 + docker/torch2.0.0-cuda11.7.dockerfile | 2 + docker/torch2.1.0-cuda11.8.dockerfile | 2 + docker/torch2.1.0-cuda12.1.dockerfile | 2 + docker/torch2.2.0-cuda11.8.dockerfile | 2 + docker/torch2.2.0-cuda12.1.dockerfile | 2 + docker/torch2.2.1-cuda11.8.dockerfile | 2 + docker/torch2.2.1-cuda12.1.dockerfile | 2 + docker/torch2.2.2-cuda11.8.dockerfile | 2 + docker/torch2.2.2-cuda12.1.dockerfile | 2 + egs/audioset/AT/zipformer/export-onnx.py | 91 ++++++------ egs/audioset/AT/zipformer/export.py | 7 +- egs/audioset/AT/zipformer/jit_pretrained.py | 21 ++- egs/audioset/AT/zipformer/onnx_pretrained.py | 55 +++----- egs/audioset/AT/zipformer/pretrained.py | 34 ++--- requirements.txt | 9 +- 21 files changed, 360 insertions(+), 114 deletions(-) create mode 100755 .github/scripts/audioset/AT/run.sh create mode 100644 .github/workflows/audioset.yml diff --git a/.github/scripts/audioset/AT/run.sh b/.github/scripts/audioset/AT/run.sh new file mode 100755 index 000000000..87856b64d --- /dev/null +++ b/.github/scripts/audioset/AT/run.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash + +set -ex + +python3 -m pip install onnxoptimizer onnxsim + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/audioset/AT + +function test_pretrained() { + repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 + repo=$(basename $repo_url) + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo/exp + git lfs pull --include pretrained.pt + ln -s pretrained.pt epoch-99.pt + ls -lh + popd + + log "test pretrained.pt" + + python3 zipformer/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --label-dict $repo/data/class_labels_indices.csv \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/4.wav + + log "test jit export" + ls -lh $repo/exp/ + python3 zipformer/export.py \ + --exp-dir $repo/exp \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --jit 1 + ls -lh $repo/exp/ + + log "test jit models" + python3 zipformer/jit_pretrained.py \ + --nn-model-filename $repo/exp/jit_script.pt \ + --label-dict $repo/data/class_labels_indices.csv \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/4.wav + + log "test onnx export" + ls -lh $repo/exp/ + python3 zipformer/export-onnx.py \ + --exp-dir $repo/exp \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 + + ls -lh $repo/exp/ + + pushd $repo/exp/ + mv model-epoch-99-avg-1.onnx model.onnx + mv model-epoch-99-avg-1.int8.onnx model.int8.onnx + popd + + ls -lh $repo/exp/ + + log "test onnx models" + for m in model.onnx model.int8.onnx; do + log "$m" + python3 zipformer/onnx_pretrained.py \ + --model-filename $repo/exp/model.onnx \ + --label-dict $repo/data/class_labels_indices.csv \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/4.wav + done + + log "prepare data for uploading to huggingface" + dst=/icefall/model-onnx + mkdir -p $dst + cp -v $repo/exp/*.onnx $dst/ + cp -v $repo/data/* $dst/ + cp -av $repo/test_wavs $dst + + ls -lh $dst + ls -lh $dst/test_wavs +} + +test_pretrained diff --git 
a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index f64446e7e..15f49f826 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -49,6 +49,8 @@ RUN pip install --no-cache-dir \ multi_quantization \ numba \ numpy \ + onnxoptimizer \ + onnxsim \ onnx \ onnxmltools \ onnxruntime \ diff --git a/.github/workflows/audioset.yml b/.github/workflows/audioset.yml new file mode 100644 index 000000000..280ef8f8e --- /dev/null +++ b/.github/workflows/audioset.yml @@ -0,0 +1,137 @@ +name: audioset + +on: + push: + branches: + - master + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: audioset-${{ github.ref }} + cancel-in-progress: true + +jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + audioset: + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + ls -lh + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run tests + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall + + .github/scripts/audioset/AT/run.sh + + - name: Show model files + shell: bash + run: | + sudo chown -R runner ./model-onnx + ls -lh ./model-onnx + chmod -x ./model-onnx/class_labels_indices.csv + + echo "----------" + ls -lh ./model-onnx/* + + - name: Upload model to huggingface + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + uses: nick-fields/retry@v3 + with: + max_attempts: 20 + timeout_seconds: 200 + shell: bash + command: | + git config --global user.email "csukuangfj@gmail.com" + git config --global user.name "Fangjun Kuang" + + rm -rf huggingface + export GIT_LFS_SKIP_SMUDGE=1 + + git clone https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 huggingface + cd huggingface + git fetch + git pull + git merge -m "merge remote" --ff origin main + cp ../model-onnx/*.onnx ./ + cp ../model-onnx/*.csv ./ + cp -a ../model-onnx/test_wavs ./ + ls -lh + git add . 
+ git status + git commit -m "update models" + git status + + git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 main || true + rm -rf huggingface + + - name: Prepare for release + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + shell: bash + run: | + d=sherpa-onnx-zipformer-audio-tagging-2024-04-09 + mv ./model-onnx $d + tar cjvf ${d}.tar.bz2 $d + ls -lh + + - name: Release exported onnx models + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + overwrite: true + file: sherpa-onnx-*.tar.bz2 + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: audio-tagging-models + diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index 33ecbf4d1..9815a8ec7 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -55,6 +55,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index b4d62b0bc..d13d2a7cb 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -55,6 +55,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile index 4d2d3058a..5936fe06a 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -69,6 +69,8 @@ RUN pip uninstall -y tqdm && \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index 31ff09ac6..e2e27b55d 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index 83b64a8d2..de1e07e69 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index ec366a898..89303797a 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.2.0-cuda11.8.dockerfile b/docker/torch2.2.0-cuda11.8.dockerfile index 143f0e066..3364477a8 100644 --- a/docker/torch2.2.0-cuda11.8.dockerfile +++ b/docker/torch2.2.0-cuda11.8.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.2.0-cuda12.1.dockerfile b/docker/torch2.2.0-cuda12.1.dockerfile index 
c6d5a771f..3cc41902d 100644 --- a/docker/torch2.2.0-cuda12.1.dockerfile +++ b/docker/torch2.2.0-cuda12.1.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.2.1-cuda11.8.dockerfile b/docker/torch2.2.1-cuda11.8.dockerfile index d874134d7..76b785622 100644 --- a/docker/torch2.2.1-cuda11.8.dockerfile +++ b/docker/torch2.2.1-cuda11.8.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.2.1-cuda12.1.dockerfile b/docker/torch2.2.1-cuda12.1.dockerfile index 6e4ef290a..55bdfa4d7 100644 --- a/docker/torch2.2.1-cuda12.1.dockerfile +++ b/docker/torch2.2.1-cuda12.1.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.2.2-cuda11.8.dockerfile b/docker/torch2.2.2-cuda11.8.dockerfile index bca40a065..02de82c50 100644 --- a/docker/torch2.2.2-cuda11.8.dockerfile +++ b/docker/torch2.2.2-cuda11.8.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/docker/torch2.2.2-cuda12.1.dockerfile b/docker/torch2.2.2-cuda12.1.dockerfile index 4fb8946e7..44ad38b8e 100644 --- a/docker/torch2.2.2-cuda12.1.dockerfile +++ b/docker/torch2.2.2-cuda12.1.dockerfile @@ -56,6 +56,8 @@ RUN pip install --no-cache-dir \ onnx \ onnxruntime \ onnxmltools \ + onnxoptimizer \ + onnxsim \ multi_quantization \ typeguard \ numpy \ diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py index af83c0e9c..9476dac62 100755 --- a/egs/audioset/AT/zipformer/export-onnx.py +++ b/egs/audioset/AT/zipformer/export-onnx.py @@ -6,56 +6,28 @@ """ This script exports a transducer model from PyTorch to ONNX. -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. +Usage of this script: -1. Download the pre-trained model + repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 + repo=$(basename $repo_url) + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo/exp + git lfs pull --include pretrained.pt + ln -s pretrained.pt epoch-99.pt + popd -cd egs/librispeech/ASR + python3 zipformer/export-onnx.py \ + --exp-dir $repo/exp \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 -repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/ -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) + pushd $repo/exp + mv model-epoch-99-avg-1.onnx model.onnx + mv model-epoch-99-avg-1.int8.onnx model.int8.onnx + popd -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. 
Export the model to ONNX - -./zipformer/export-onnx.py \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal False \ - --chunk-size "16,32,64,-1" \ - --left-context-frames "64,128,256,-1" - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py and ./onnx_check.py for how to +See ./onnx_pretrained.py use the exported ONNX models. """ @@ -66,9 +38,11 @@ from typing import Dict import k2 import onnx +import onnxoptimizer import torch import torch.nn as nn from onnxruntime.quantization import QuantType, quantize_dynamic +from onnxsim import simplify from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params from zipformer import Zipformer2 @@ -261,6 +235,29 @@ def export_audio_tagging_model_onnx( add_meta_data(filename=filename, meta_data=meta_data) +def optimize_model(filename): + # see + # https://github.com/microsoft/onnxruntime/issues/1899#issuecomment-534806537 + # and + # https://github.com/onnx/onnx/issues/582#issuecomment-937788108 + # and + # https://github.com/onnx/optimizer/issues/110 + # and + # https://qiita.com/Yossy_Hal/items/34f3b2aef2199baf7f5f + passes = ["eliminate_unused_initializer"] + onnx_model = onnx.load(filename) + onnx_model = onnxoptimizer.optimize(onnx_model, passes) + + model_simp, check = simplify(onnx_model) + if check: + logging.info("Simplified the model!") + onnx_model = model_simp + else: + logging.info("Failed to simplify the model!") + + onnx.save(onnx_model, filename) + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -389,6 +386,7 @@ def main(): model_filename, opset_version=opset_version, ) + optimize_model(model_filename) logging.info(f"Exported audio tagging model to {model_filename}") # Generate int8 quantization models @@ -403,6 +401,7 @@ def main(): op_types_to_quantize=["MatMul"], weight_type=QuantType.QInt8, ) + optimize_model(model_filename_int8) if __name__ == "__main__": diff --git a/egs/audioset/AT/zipformer/export.py b/egs/audioset/AT/zipformer/export.py index bdcf8b7dd..6ceeca8de 100755 --- a/egs/audioset/AT/zipformer/export.py +++ b/egs/audioset/AT/zipformer/export.py @@ -25,7 +25,7 @@ Usage: -Note: This is a example for librispeech dataset, if you are using different +Note: This is an example for AudioSet dataset, if you are using different dataset, you should change the argument values according to your dataset. (1) Export to torchscript model using torch.jit.script() @@ -42,6 +42,7 @@ load it by `torch.jit.load("jit_script.pt")`. Check ./jit_pretrained.py for its usage. Check https://github.com/k2-fsa/sherpa +and https://github.com/k2-fsa/sherpa-onnx for how to use the exported models outside of icefall. (2) Export `model.state_dict()` @@ -55,13 +56,13 @@ for how to use the exported models outside of icefall. It will generate a file `pretrained.pt` in the given `exp_dir`. You can later load it by `icefall.checkpoint.load_checkpoint()`. 
-To use the generated file with `zipformer/decode.py`, +To use the generated file with `zipformer/evaluate.py`, you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt - cd /path/to/egs/librispeech/ASR + cd /path/to/egs/audioset/AT ./zipformer/evaluate.py \ --exp-dir ./zipformer/exp \ --use-averaged-model False \ diff --git a/egs/audioset/AT/zipformer/jit_pretrained.py b/egs/audioset/AT/zipformer/jit_pretrained.py index 8e3afcb6f..403308fcf 100755 --- a/egs/audioset/AT/zipformer/jit_pretrained.py +++ b/egs/audioset/AT/zipformer/jit_pretrained.py @@ -28,10 +28,20 @@ You can use the following command to get the exported models: Usage of this script: -./zipformer/jit_pretrained.py \ - --nn-model-filename ./zipformer/exp/cpu_jit.pt \ - /path/to/foo.wav \ - /path/to/bar.wav + repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 + repo=$(basename $repo_url) + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo/exp + git lfs pull --include jit_script.pt + popd + + python3 zipformer/jit_pretrained.py \ + --nn-model-filename $repo/exp/jit_script.pt \ + --label-dict $repo/data/class_labels_indices.csv \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/4.wav """ import argparse @@ -168,7 +178,8 @@ def main(): topk_prob, topk_index = logit.sigmoid().topk(5) topk_labels = [label_dict[index.item()] for index in topk_index] logging.info( - f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" + f"{filename}: Top 5 predicted labels are {topk_labels} with " + f"probability of {topk_prob.tolist()}" ) logging.info("Done") diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py index c7753715a..1d3093d99 100755 --- a/egs/audioset/AT/zipformer/onnx_pretrained.py +++ b/egs/audioset/AT/zipformer/onnx_pretrained.py @@ -17,48 +17,25 @@ # limitations under the License. """ This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: -We use the pre-trained model from -https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/ -as an example to show how to use this file. +Usage of this script: -1. Download the pre-trained model + repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 + repo=$(basename $repo_url) + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo/exp + git lfs pull --include "*.onnx" + popd -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/ -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -It will generate the following 3 files inside $repo/exp: - - - model-epoch-99-avg-1.onnx - -3. 
Run this file - -./zipformer/onnx_pretrained.py \ - --model-filename $repo/exp/model-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav + for m in model.onnx model.int8.onnx; do + python3 zipformer/onnx_pretrained.py \ + --model-filename $repo/exp/model.onnx \ + --label-dict $repo/data/class_labels_indices.csv \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/4.wav + done """ import argparse diff --git a/egs/audioset/AT/zipformer/pretrained.py b/egs/audioset/AT/zipformer/pretrained.py index 60e4d0518..bdbd799fa 100755 --- a/egs/audioset/AT/zipformer/pretrained.py +++ b/egs/audioset/AT/zipformer/pretrained.py @@ -18,27 +18,25 @@ This script loads a checkpoint and uses it to decode waves. You can generate the checkpoint with the following command: -Note: This is a example for librispeech dataset, if you are using different +Note: This is an example for the AudioSet dataset, if you are using different dataset, you should change the argument values according to your dataset. - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - Usage of this script: -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - /path/to/foo.wav \ - /path/to/bar.wav + repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 + repo=$(basename $repo_url) + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo/exp + git lfs pull --include pretrained.pt + popd - -You can also use `./zipformer/exp/epoch-xx.pt`. - -Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py + python3 zipformer/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --label-dict $repo/data/class_labels_indices.csv \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/4.wav """ @@ -189,7 +187,8 @@ def main(): topk_prob, topk_index = logit.sigmoid().topk(5) topk_labels = [label_dict[index.item()] for index in topk_index] logging.info( - f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" + f"{filename}: Top 5 predicted labels are {topk_labels} with " + f"probability of {topk_prob.tolist()}" ) logging.info("Done") @@ -199,4 +198,5 @@ if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/requirements.txt b/requirements.txt index 6bafa6aca..8410453f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,13 +8,14 @@ pypinyin==0.50.0 tensorboard typeguard dill -onnx==1.15.0 -onnxruntime==1.16.3 +onnx>=1.15.0 +onnxruntime>=1.16.3 +onnxoptimizer # style check session: black==22.3.0 isort==5.10.1 -flake8==5.0.4 +flake8==5.0.4 # cantonese word segment support -pycantonese==3.4.0 \ No newline at end of file +pycantonese==3.4.0 From ba5b2e854bcf9e8ce3cde62d05399b7701b28ec5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Apr 2024 09:03:30 +0800 Subject: [PATCH 152/216] Return probs in audio tagging onnx models (#1586) --- egs/audioset/AT/zipformer/export-onnx.py | 10 ++++++---- egs/audioset/AT/zipformer/onnx_pretrained.py | 21 ++++++++++---------- requirements.txt | 1 + 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/egs/audioset/AT/zipformer/export-onnx.py 
b/egs/audioset/AT/zipformer/export-onnx.py index 9476dac62..24b7717b4 100755 --- a/egs/audioset/AT/zipformer/export-onnx.py +++ b/egs/audioset/AT/zipformer/export-onnx.py @@ -164,7 +164,7 @@ class OnnxAudioTagger(nn.Module): A 1-D tensor of shape (N,). Its dtype is torch.int64 Returns: Return a tensor containing: - - logits, A 2-D tensor of shape (N, num_classes) + - probs, A 2-D tensor of shape (N, num_classes) """ x, x_lens = self.encoder_embed(x, x_lens) @@ -177,7 +177,8 @@ class OnnxAudioTagger(nn.Module): # Note that this is slightly different from model.py for better # support of onnx logits = logits.mean(dim=1) - return logits + probs = logits.sigmoid() + return probs def export_audio_tagging_model_onnx( @@ -220,15 +221,16 @@ def export_audio_tagging_model_onnx( dynamic_axes={ "x": {0: "N", 1: "T"}, "x_lens": {0: "N"}, - "logits": {0: "N"}, + "probs": {0: "N"}, }, ) meta_data = { - "model_type": "zipformer2_at", + "model_type": "zipformer2", "version": "1", "model_author": "k2-fsa", "comment": "zipformer2 audio tagger", + "url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer", } logging.info(f"meta_data: {meta_data}") diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py index 1d3093d99..82fa3d45b 100755 --- a/egs/audioset/AT/zipformer/onnx_pretrained.py +++ b/egs/audioset/AT/zipformer/onnx_pretrained.py @@ -20,17 +20,17 @@ This script loads ONNX models and uses them to decode waves. Usage of this script: - repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 + repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 repo=$(basename $repo_url) - GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url - pushd $repo/exp + git clone $repo_url + pushd $repo git lfs pull --include "*.onnx" popd for m in model.onnx model.int8.onnx; do python3 zipformer/onnx_pretrained.py \ - --model-filename $repo/exp/model.onnx \ - --label-dict $repo/data/class_labels_indices.csv \ + --model-filename $repo/model.onnx \ + --label-dict $repo/class_labels_indices.csv \ $repo/test_wavs/1.wav \ $repo/test_wavs/2.wav \ $repo/test_wavs/3.wav \ @@ -125,7 +125,7 @@ class OnnxModel: A 2-D tensor of shape (N,). 
Its dtype is torch.int64 Returns: Return a Tensor: - - logits, its shape is (N, num_classes) + - probs, its shape is (N, num_classes) """ out = self.model.run( [ @@ -208,13 +208,14 @@ def main(): ) feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - logits = model(features, feature_lengths) + probs = model(features, feature_lengths) - for filename, logit in zip(args.sound_files, logits): - topk_prob, topk_index = logit.sigmoid().topk(5) + for filename, prob in zip(args.sound_files, probs): + topk_prob, topk_index = prob.topk(5) topk_labels = [label_dict[index.item()] for index in topk_index] logging.info( - f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" + f"{filename}: Top 5 predicted labels are {topk_labels} with " + f"probability of {topk_prob.tolist()}" ) logging.info("Decoding Done") diff --git a/requirements.txt b/requirements.txt index 8410453f9..226adaba1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ dill onnx>=1.15.0 onnxruntime>=1.16.3 onnxoptimizer +onnxsim # style check session: black==22.3.0 From ed6bc200e37aaea0129ae32095642c096d4ffad5 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 11 Apr 2024 19:35:25 +0800 Subject: [PATCH 153/216] Update train.py (#1590) --- egs/librispeech/ASR/zipformer/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 1111d32ab..04caf2fd8 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -966,7 +966,10 @@ def train_one_epoch( scaler.step(optimizer) scaler.update() optimizer.zero_grad() - except: # noqa + except Exception as e: + logging.info( + f"Caught exception: {e}." + ) save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise From 9f8f0bceb5e734cfebf3ba8af213c0b4715a266c Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 20 Apr 2024 23:02:02 +0900 Subject: [PATCH 154/216] Update prepare.sh (#1601) --- egs/multi_zh-hans/ASR/prepare.sh | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index fa515ed50..d1c7e695c 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -299,6 +299,15 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then log "Compute KeSpeech fbank for test/dev" ./local/compute_fbank_kespeech_dev_test.py + if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then + pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz") + lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz + fi + if [ ! 
-f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then + pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz") + lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz + fi + touch data/fbank/.kespeech.done fi fi From 368b7d10a7058d920ce3feaa231c9fc0760eec4b Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 24 Apr 2024 14:31:25 +0800 Subject: [PATCH 155/216] clear log handlers before setup (#1603) --- icefall/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/icefall/utils.py b/icefall/utils.py index 2cb2edf93..ec6aee6d0 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -110,6 +110,13 @@ def str2bool(v): raise argparse.ArgumentTypeError("Boolean value expected.") +def clear_log_handlers(): + logger = logging.getLogger() + handlers = logger.handlers[:] + for handler in handlers: + logger.removeHandler(handler) + + def setup_logger( log_filename: Pathlike, log_level: str = "info", @@ -126,6 +133,8 @@ def setup_logger( use_console: True to also print logs to console. """ + clear_log_handlers() + now = datetime.now() date_time = now.strftime("%Y-%m-%d-%H-%M-%S") if dist.is_available() and dist.is_initialized(): From df36f93bd81a61cf1bcff23ea465292b33b3a268 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 24 Apr 2024 17:00:42 +0800 Subject: [PATCH 156/216] add small-scaled model for audio tagging (#1604) --- egs/audioset/AT/RESULTS.md | 51 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/egs/audioset/AT/RESULTS.md b/egs/audioset/AT/RESULTS.md index 0c75dfe4e..0128b7018 100644 --- a/egs/audioset/AT/RESULTS.md +++ b/egs/audioset/AT/RESULTS.md @@ -5,6 +5,8 @@ See for more details [zipformer](./zipformer) +#### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -42,3 +44,52 @@ python zipformer/evaluate.py \ --exp-dir zipformer/exp_at_as_full \ --max-duration 500 ``` + + +#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +The model achieves the following mean averaged precision on AudioSet: + +| Model | mAP | +| ------ | ------- | +| Zipformer-S-AT | 45.1 | + +The training command is: + +```bash +export CUDA_VISIBLE_DEVICES="4,5,6,7" +subset=full + +python zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --exp-dir zipformer/exp_small_at_as_${subset} \ + --start-epoch 1 \ + --use-fp16 1 \ + --num-events 527 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --audioset-subset $subset \ + --max-duration 1200 \ + --enable-musan True \ + --master-port 13455 +``` + +The evaluation command is: + +```bash +python zipformer/evaluate.py \ + --epoch 31 \ + --avg 4 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --exp-dir zipformer/exp_small_at_as_full \ + --max-duration 500 +``` \ No newline at end of file From 25cabb76635d98a65ad1c4d7733052b5b00ad45e Mon Sep 17 00:00:00 2001 From: zzasdf <68544676+zzasdf@users.noreply.github.com> Date: Thu, 25 
Apr 2024 22:40:07 +0800 Subject: [PATCH 157/216] fix error in padding computing (#1607) --- egs/librispeech/SSL/zipformer/hubert_ce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/SSL/zipformer/hubert_ce.py b/egs/librispeech/SSL/zipformer/hubert_ce.py index ba4e1cddd..1ac368a1d 100644 --- a/egs/librispeech/SSL/zipformer/hubert_ce.py +++ b/egs/librispeech/SSL/zipformer/hubert_ce.py @@ -429,7 +429,7 @@ class HubertModel(nn.Module): # padding_mask: (B, T), bool # mask_indices: (B, T), bool x = x.transpose(0, 1) - x, x_lens = self.encoder(x, ~padding_mask.sum(dim=-1)) + x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1)) x = x.transpose(0, 1) if features_only: From 9a17f4ce410d23080c26c7e2257e9a14f312862d Mon Sep 17 00:00:00 2001 From: Dongji Gao Date: Thu, 25 Apr 2024 12:55:44 -0400 Subject: [PATCH 158/216] add OTC related scripts using phone as units instead of BPEs (#1602) * add otc related scripts using phone instead of bpe --- .../WSASR/conformer_ctc2/decode_phone.py | 592 +++++++++ .../WSASR/conformer_ctc2/train_phone.py | 1124 +++++++++++++++++ egs/librispeech/WSASR/local/download_lm.py | 146 +++ .../WSASR/local/prepare_otc_lang.py | 469 +++++++ egs/librispeech/WSASR/prepare.sh | 44 +- egs/librispeech/WSASR/shared | 1 + icefall/otc_phone_graph_compiler.py | 232 ++++ 7 files changed, 2599 insertions(+), 9 deletions(-) create mode 100755 egs/librispeech/WSASR/conformer_ctc2/decode_phone.py create mode 100755 egs/librispeech/WSASR/conformer_ctc2/train_phone.py create mode 100755 egs/librispeech/WSASR/local/download_lm.py create mode 100755 egs/librispeech/WSASR/local/prepare_otc_lang.py create mode 120000 egs/librispeech/WSASR/shared create mode 100644 icefall/otc_phone_graph_compiler.py diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py new file mode 100755 index 000000000..b6b1cb020 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py @@ -0,0 +1,592 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Quandong Wang) +# 2023 Johns Hopkins University (Author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
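+
+"""
+Usage (a minimal sketch built from the options defined in get_parser() below;
+adjust the checkpoint epoch, experiment dir and lang dir to your own setup):
+
+./conformer_ctc2/decode_phone.py \
+    --epoch 20 \
+    --avg 5 \
+    --method 1best \
+    --exp-dir conformer_ctc2/exp \
+    --lang-dir data/lang_phone
+"""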
+ + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import get_lattice, one_best_decoding +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--otc-token", + type=str, + default="", + help="OTC token", + ) + + parser.add_argument( + "--blank-bias", + type=float, + default=0, + help="bias (log-prob) added to blank token during decoding", + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--method", + type=str, + default="ctc-greedy-search", + help="""Decoding method. + Supported values are: + - (0) 1best. Extract the best path from the decoding lattice as the + decoding result. + """, + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=0, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_phone", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. 
+ It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "feature_dim": 80, + "nhead": 8, + "dim_feedforward": 2048, + "encoder_dim": 512, + "num_encoder_layers": 12, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. 
+ """ + device = HLG.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + nnet_output[:, :, 0] += params.blank_bias + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="trunc", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="trunc", + ), + ), + 1, + ).to(torch.int32) + + decoding_graph = HLG + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor + 2, + ) + + if params.method in ["1best"]: + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + + return {key: hyps} + else: + assert False, f"Unsupported decoding method: {params.method}" + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + batch=batch, + word_table=word_table, + G=G, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" 
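+            # Decoding produced no hypotheses for this batch: pair every
+            # reference with an empty hypothesis so these utterances are still
+            # counted when the WER is computed.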
+ this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + # remove otc_token from decoding units + max_token_id = len(lexicon.tokens) - 1 + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + params.num_classes = num_classes + + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = HLG.to(device) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.encoder_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + 
f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + word_table=lexicon.word_table, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py new file mode 100755 index 000000000..b276d0587 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py @@ -0,0 +1,1124 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conformer_ctc2/train.py \ + --world-size 4 \ + --manifest-dir data/ssl \ + --train-manifest librispeech_cuts_train-clean-100_0.17_0.17_0.17.jsonl.gz \ + --exp-dir conformer_ctc2/exp \ + --lang-dir data/lang_bpe_200 \ + --otc-token "" \ + --feature-dim 768 \ + --allow-bypass-arc true \ + --allow-self-loop-arc true \ + --initial-bypass-weight -19 \ + --initial-self-loop-weight 3.75 \ + --bypass-weight-decay 0.975 \ + --self-loop-weight-decay 0.999 \ + --show-alignment true +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.decode import one_best_decoding +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions_otc, + get_texts, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc2/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_200", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="""Number of features extracted in feature extraction stage.last dimension of feature vector. + 80 when using fbank features and 768 or 1024 whn using wave2vec""", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.0, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=0, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=10, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--otc-token",
+        type=str,
+        default="_",
+        help="OTC token",
+    )
+
+    parser.add_argument(
+        "--allow-bypass-arc",
+        type=str2bool,
+        default=True,
+        help="""Whether to add bypass arc to training graph for substitution
+        and insertion errors (wrong or extra words in the transcript).""",
+    )
+
+    parser.add_argument(
+        "--allow-self-loop-arc",
+        type=str2bool,
+        default=True,
+        help="""Whether to add self-loop arc to training graph for deletion errors
+        (missing words in the transcript).""",
+    )
+
+    parser.add_argument(
+        "--initial-bypass-weight",
+        type=float,
+        default=0.0,
+        help="Initial weight associated with bypass arc",
+    )
+
+    parser.add_argument(
+        "--initial-self-loop-weight",
+        type=float,
+        default=0.0,
+        help="Initial weight associated with self-loop arc",
+    )
+
+    parser.add_argument(
+        "--bypass-weight-decay",
+        type=float,
+        default=1.0,
+        help="""Weight decay factor of bypass arc weight:
+        bypass_arc_weight = initial_bypass_weight * bypass_weight_decay ^ ith-epoch""",
+    )
+
+    parser.add_argument(
+        "--self-loop-weight-decay",
+        type=float,
+        default=1.0,
+        help="""Weight decay factor of self-loop arc weight:
+        self_loop_arc_weight = initial_self_loop_weight * self_loop_weight_decay ^ ith-epoch""",
+    )
+
+    parser.add_argument(
+        "--show-alignment",
+        type=str2bool,
+        default=True,
+        help="Whether to print OTC alignment during training",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+    - best_train_loss: Best training loss so far. It is used to select
+                       the model that has the lowest training loss. It is
+                       updated during the training.
+
+    - best_valid_loss: Best validation loss so far. It is used to select
+                       the model that has the lowest validation loss. It is
+                       updated during the training.
+
+    - best_train_epoch: It is the epoch that has the best training loss.
+
+    - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+    - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+    - feature_dim: The model input dim. It has to match the one used
+                   in computing features.
+
+    - subsampling_factor: The subsampling factor for the model.
+
+    - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+    - beam_size: It is used in k2.ctc_loss
+
+    - reduction: It is used in k2.ctc_loss
+
+    - use_double_scores: It is used in k2.ctc_loss
+
+    - warm_step: The warm_step for Noam optimizer.
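# Standalone sketch of the per-epoch arc-weight schedule that the
# --bypass-weight-decay / --self-loop-weight-decay options above describe. The numbers
# are taken from the usage example at the top of this file (-19 / 0.975 and
# 3.75 / 0.999); cur_epoch counts from 1, as in compute_loss() further down.
initial_bypass_weight, bypass_weight_decay = -19.0, 0.975
initial_self_loop_weight, self_loop_weight_decay = 3.75, 0.999

for cur_epoch in (1, 5, 20):
    bypass_weight = initial_bypass_weight * bypass_weight_decay ** (cur_epoch - 1)
    self_loop_weight = initial_self_loop_weight * self_loop_weight_decay ** (cur_epoch - 1)
    print(cur_epoch, round(bypass_weight, 2), round(self_loop_weight, 3))
# 1  -19.0   3.75
# 5  -17.17  3.735
# 20 -11.74  3.679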
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 1, + "reset_interval": 200, + "valid_interval": 800, # For the 100h subset, use 800 + "alignment_interval": 100, + # parameters for conformer + "subsampling_factor": 4, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + # parameters for ctc loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
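# Standalone sketch of the resume logic documented in load_checkpoint_if_available
# above. The helper name and the example values are illustrative only, not part of
# the patch.
from pathlib import Path
from typing import Optional


def resolve_resume_checkpoint(
    exp_dir: Path, start_batch: int, start_epoch: int
) -> Optional[Path]:
    # Mirrors the branching at the top of load_checkpoint_if_available.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run, nothing to load


# resolve_resume_checkpoint(Path("conformer_ctc2/exp"), 0, 1)     -> None
# resolve_resume_checkpoint(Path("conformer_ctc2/exp"), 0, 5)     -> .../epoch-4.pt
# resolve_resume_checkpoint(Path("conformer_ctc2/exp"), 8000, 5)  -> .../checkpoint-8000.pt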
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + batch: dict, + graph_compiler: OtcPhoneTrainingGraphCompiler, + is_training: bool, + warmup: float = 2.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute OTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model( + feature, supervisions, warmup=warmup + ) + # Set the probability of OTC token as the average of non-blank tokens + # under the assumption that blank is the first and + # OTC token is the last token in tokens.txt + _, _, V = nnet_output.shape + + otc_token_log_prob = torch.logsumexp( + nnet_output[:, :, 1:], dim=-1, keepdim=True + ) - torch.log(torch.tensor([V - 1])).to(device) + + nnet_output = torch.cat([nnet_output, otc_token_log_prob], dim=-1) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts, utt_ids, verbatim_texts = encode_supervisions_otc( + supervisions, subsampling_factor=params.subsampling_factor + ) + + bypass_weight = graph_compiler.initial_bypass_weight * ( + graph_compiler.bypass_weight_decay ** (params.cur_epoch - 1) + ) + self_loop_weight = graph_compiler.initial_self_loop_weight * ( + graph_compiler.self_loop_weight_decay ** (params.cur_epoch - 1) + ) + + decoding_graph = graph_compiler.compile( + texts=texts, + allow_bypass_arc=params.allow_bypass_arc, + allow_self_loop_arc=params.allow_self_loop_arc, + bypass_weight=bypass_weight, + self_loop_weight=self_loop_weight, + ) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=3, + ) + + otc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + 
output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + assert params.att_rate == 0.0 + loss = otc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["otc_loss"] = otc_loss.detach().cpu().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + if params.show_alignment: + if params.batch_idx_train % params.alignment_interval == 0: + for index, utt_id in enumerate(utt_ids): + verbatim_text = verbatim_texts[index] + utt_id = utt_ids[index] + + lattice = k2.intersect_dense( + decoding_graph, + dense_fsa_vec, + params.beam_size, + ) + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=params.use_double_scores, + ) + hyp_ids = get_texts(best_path)[index] + hyp_text_list = [graph_compiler.word_table[i] for i in hyp_ids] + hyp_text = " ".join(hyp_text_list) + + logging.info(f"[utterance id]: {utt_id}") + logging.info(f"[verbatim text]: {verbatim_text}") + logging.info(f"[best alignment]: {hyp_text}") + logging.info(bypass_weight) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: OtcPhoneTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + graph_compiler: OtcPhoneTrainingGraphCompiler, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + # scaler.scale(loss).backward() + + try: + # loss.backward() + scaler.scale(loss).backward() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error(f"failing batch size:{batch_size} ") + raise + + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + if loss_info["otc_loss"] == float("inf"): + logging.error("Your loss contains inf, something goes wrong") + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. 
+ The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + lexicon = Lexicon(params.lang_dir) + graph_compiler = OtcPhoneTrainingGraphCompiler( + lexicon, + otc_token=params.otc_token, + device=device, + initial_bypass_weight=params.initial_bypass_weight, + initial_self_loop_weight=params.initial_self_loop_weight, + bypass_weight_decay=params.bypass_weight_decay, + self_loop_weight_decay=params.self_loop_weight_decay, + ) + + # remove OTC token as it is the average of all non-blank tokens + max_token_id = graph_compiler.get_max_token_id() - 1 + # add blank + num_classes = max_token_id + 1 + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.encoder_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + ) + + print(model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + graph_compiler=graph_compiler, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: OtcPhoneTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + warmup=0.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.otc_token = f"{args.otc_token}" + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/download_lm.py b/egs/librispeech/WSASR/local/download_lm.py new file mode 100755 index 000000000..5a36ff2a9 --- /dev/null +++ b/egs/librispeech/WSASR/local/download_lm.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file downloads the following LibriSpeech LM files: + + - 3-gram.pruned.1e-7.arpa.gz + - 4-gram.arpa.gz + - librispeech-vocab.txt + - librispeech-lexicon.txt + - librispeech-lm-norm.txt.gz + +from http://www.openslr.org/resources/11 +and save them in the user provided directory. + +Files are not re-downloaded if they already exist. + +Usage: + ./local/download_lm.py --out-dir ./download/lm +""" + +import argparse +import gzip +import logging +import os +import shutil +from pathlib import Path + +from tqdm.auto import tqdm + + +# This function is copied from lhotse +def tqdm_urlretrieve_hook(t): + """Wraps tqdm instance. + Don't forget to close() or __exit__() + the tqdm instance once you're done with it (easiest using `with` syntax). + Example + ------- + >>> from urllib.request import urlretrieve + >>> with tqdm(...) as t: + ... reporthook = tqdm_urlretrieve_hook(t) + ... urlretrieve(..., reporthook=reporthook) + + Source: https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py + """ + last_b = [0] + + def update_to(b=1, bsize=1, tsize=None): + """ + b : int, optional + Number of blocks transferred so far [default: 1]. + bsize : int, optional + Size of each block (in tqdm units) [default: 1]. + tsize : int, optional + Total size (in tqdm units). If [default: None] or -1, + remains unchanged. + """ + if tsize not in (None, -1): + t.total = tsize + displayed = t.update((b - last_b[0]) * bsize) + last_b[0] = b + return displayed + + return update_to + + +# This function is copied from lhotse +def urlretrieve_progress(url, filename=None, data=None, desc=None): + """ + Works exactly like urllib.request.urlretrieve, but attaches a tqdm hook to + display a progress bar of the download. + Use "desc" argument to display a user-readable string that informs what is + being downloaded. 
+ """ + from urllib.request import urlretrieve + + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=desc) as t: + reporthook = tqdm_urlretrieve_hook(t) + return urlretrieve(url=url, filename=filename, reporthook=reporthook, data=data) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", type=str, help="Output directory.") + + args = parser.parse_args() + return args + + +def main(out_dir: str): + url = "http://www.openslr.org/resources/11" + out_dir = Path(out_dir) + + files_to_download = ( + "3-gram.pruned.1e-7.arpa.gz", + "4-gram.arpa.gz", + "librispeech-vocab.txt", + "librispeech-lexicon.txt", + "librispeech-lm-norm.txt.gz", + ) + + for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): + filename = out_dir / f + if filename.is_file() is False: + urlretrieve_progress( + f"{url}/{f}", + filename=filename, + desc=f"Downloading {filename}", + ) + else: + logging.info(f"{filename} already exists - skipping") + + if ".gz" in str(filename): + unzipped = Path(os.path.splitext(filename)[0]) + if unzipped.is_file() is False: + with gzip.open(filename, "rb") as f_in: + with open(unzipped, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + else: + logging.info(f"{unzipped} already exist - skipping") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + logging.info(f"out_dir: {args.out_dir}") + + main(out_dir=args.out_dir) diff --git a/egs/librispeech/WSASR/local/prepare_otc_lang.py b/egs/librispeech/WSASR/local/prepare_otc_lang.py new file mode 100755 index 000000000..01865b865 --- /dev/null +++ b/egs/librispeech/WSASR/local/prepare_otc_lang.py @@ -0,0 +1,469 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2024 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. 
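# Standalone sketch of step 1 above (made-up words and phones), using the
# add_disambig_symbols() helper defined later in this script. Two entries share the
# pronunciation "ax b" and "ax" is also a prefix of it, so disambiguation symbols
# are appended:
toy_lexicon = [("abc", ["ax", "b"]), ("abd", ["ax", "b"]), ("ax", ["ax"])]
# add_disambig_symbols(toy_lexicon) would return
#   [("abc", ["ax", "b", "#1"]),
#    ("abd", ["ax", "b", "#2"]),
#    ("ax", ["ax", "#1"])]
# together with max_disambig == 2.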
+""" +import argparse +import logging +import math +import re +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch + +from icefall.lexicon import write_lexicon +from icefall.utils import str2bool + +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--otc-token", + type=str, + default="", + help="The OTC token in lexicon", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + """, + ) + + return parser.parse_args() + + +def read_lexicon( + filename: str, +) -> List[Tuple[str, List[str]]]: + """Read a lexicon from `filename`. + + Each line in the lexicon contains "word p1 p2 p3 ...". + That is, the first field is a word and the remaining + fields are tokens. Fields are separated by space(s). + + Args: + filename: + Path to the lexicon.txt + + Returns: + A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3, 'p4'])] + """ + ans = [] + + with open(filename, "r", encoding="utf-8") as f: + whitespace = re.compile("[ \t]+") + for line in f: + a = whitespace.split(line.strip(" \t\r\n")) + if len(a) == 0: + continue + + if len(a) < 2: + logging.info(f"Found bad line {line} in lexicon file {filename}") + logging.info("Every line is expected to contain at least 2 fields") + continue + word = a[0] + if word == "": + logging.info(f"Found bad line {line} in lexicon file {filename}") + logging.info(" should not be a valid word") + continue + + tokens = a[1:] + ans.append((word, tokens)) + + return ans + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. 
+ Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. + count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). + issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map( + symbols: List[str], +) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. + Returns: + A dict containing the mapping between symbols and IDs. + """ + return {sym: i for i, sym in enumerate(symbols)} + + +def add_self_loops( + arcs: List[List[Any]], disambig_token: int, disambig_word: int +) -> List[List[Any]]: + """Adds self-loops to states of an FST to propagate disambiguation symbols + through it. They are added on each state with non-epsilon output symbols + on at least one arc out of the state. + + See also fstaddselfloops.pl from Kaldi. One difference is that + Kaldi uses OpenFst style FSTs and it has multiple final states. + This function uses k2 style FSTs and it does not need to add self-loops + to the final state. + + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + + Args: + arcs: + A list-of-list. The sublist contains + `[src_state, dest_state, label, aux_label, score]` + disambig_token: + It is the token ID of the symbol `#0`. + disambig_word: + It is the word ID of the symbol `#0`. + + Return: + Return new `arcs` containing self-loops. + """ + states_needs_self_loops = set() + for arc in arcs: + src, dst, ilabel, olabel, score = arc + if olabel != 0: + states_needs_self_loops.add(src) + + ans = [] + for s in states_needs_self_loops: + ans.append([s, s, disambig_token, disambig_word, 0]) + + return arcs + ans + + +def lexicon_to_fst( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + sil_token: str = "SIL", + sil_prob: float = 0.5, + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format) with optional silence at + the beginning and end of each word. + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. 
+ word2id: + A dict mapping words to IDs. + sil_token: + The silence token. + sil_prob: + The probability for adding a silence at the beginning and end + of the word. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + assert sil_prob > 0.0 and sil_prob < 1.0 + # CAUTION: we use score, i.e, negative cost. + sil_score = math.log(sil_prob) + no_sil_score = math.log(1.0 - sil_prob) + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + arcs = [] + + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + sil_token = token2id[sil_token] + + arcs.append([start_state, loop_state, eps, eps, no_sil_score]) + arcs.append([start_state, sil_state, eps, eps, sil_score]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) + + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + tokens = [token2id[i] for i in tokens] + + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last token of this word + # It has two out-going arcs, one to the loop state, + # the other one to the sil_state. + i = len(tokens) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + lexicon_filename = lang_dir / "lexicon.txt" + otc_token = args.otc_token + sil_token = "SIL" + sil_prob = 0.5 + + lexicon = read_lexicon(lexicon_filename) + tokens = get_tokens(lexicon) + words = get_words(lexicon) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + lexicon.append((otc_token, [otc_token])) + tokens.append(otc_token) + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in tokens + tokens.append(f"#{i}") + + assert "" not in tokens + tokens = [""] + tokens + + assert "" not in words + assert "#0" not in words + assert "" not in words + assert "" not in words + + words = [""] + words + [otc_token, "#0", "", ""] + + token2id = generate_id_map(tokens) + word2id = generate_id_map(words) + + write_mapping(lang_dir / "tokens.txt", token2id) + write_mapping(lang_dir / "words.txt", word2id) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst( + lexicon, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + ) + + L_disambig = lexicon_to_fst( + 
lexicon_disambig, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/prepare.sh b/egs/librispeech/WSASR/prepare.sh index 0d2a67259..c242bcbb0 100755 --- a/egs/librispeech/WSASR/prepare.sh +++ b/egs/librispeech/WSASR/prepare.sh @@ -30,7 +30,8 @@ stop_stage=100 # - librispeech-lm-norm.txt.gz # otc_token="" -feature_type="ssl" +# ssl or fbank +feature_type="fbank" dl_dir=$PWD/download manifests_dir="data/manifests" @@ -40,9 +41,6 @@ lm_dir="data/lm" perturb_speed=false -# ssl or fbank - -. ./cmd.sh . shared/parse_options.sh || exit 1 # vocab size for sentence piece models. @@ -192,7 +190,23 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare G" + log "Stage 5: Prepare phone based lang" + lang_dir="data/lang_phone" + mkdir -p ${lang_dir} + + if [ ! -f $lang_dir/lexicon.txt ]; then + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/librispeech-lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_otc_lang.py --lang-dir $lang_dir + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare G" # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm @@ -216,18 +230,30 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi fi -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Compile HLG" +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compile HLG" # Note If ./local/compile_hlg.py throws OOM, # please switch to the following command # # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone - for vocab_size in ${vocab_sizes[@]}; do - bpe_lang_dir="data/lang_bpe_${vocab_size}" + lang_dir="data/lang_bpe_${vocab_size}" echo "LM DIR: ${lm_dir}" ./local/compile_hlg.py \ --lm-dir "${lm_dir}" \ --lang-dir "${bpe_lang_dir}" done fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 7: Compile HLG" + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + lang_dir="data/lang_phone" + echo "LM DIR: ${lm_dir}" + ./local/compile_hlg.py \ + --lm-dir "${lm_dir}" \ + --lang-dir "${lang_dir}" +fi diff --git a/egs/librispeech/WSASR/shared b/egs/librispeech/WSASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/librispeech/WSASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/icefall/otc_phone_graph_compiler.py b/icefall/otc_phone_graph_compiler.py new file mode 100644 index 000000000..bebdffe0c --- /dev/null +++ b/icefall/otc_phone_graph_compiler.py @@ -0,0 +1,232 @@ +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pathlib import Path +from typing import List, Union + +import k2 +import torch + +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +class OtcPhoneTrainingGraphCompiler(object): + def __init__( + self, + lexicon: Lexicon, + otc_token: str, + oov: str = "", + device: Union[str, torch.device] = "cpu", + initial_bypass_weight: float = 0.0, + initial_self_loop_weight: float = 0.0, + bypass_weight_decay: float = 0.0, + self_loop_weight_decay: float = 0.0, + ) -> None: + """ + Args: + lexicon: + It is built from `data/lang/lexicon.txt`. + otc_token: + The special token in OTC that represent all non-blank tokens + device: + It indicates CPU or CUDA. + """ + self.device = device + L_inv = lexicon.L_inv.to(self.device) + assert L_inv.requires_grad is False + assert oov in lexicon.word_table + + self.L_inv = k2.arc_sort(L_inv) + self.oov_id = lexicon.word_table[oov] + self.otc_id = lexicon.word_table[otc_token] + self.word_table = lexicon.word_table + + max_token_id = max(lexicon.tokens) + ctc_topo = k2.ctc_topo(max_token_id, modified=False) + self.ctc_topo = ctc_topo.to(self.device) + self.max_token_id = max_token_id + + self.initial_bypass_weight = initial_bypass_weight + self.initial_self_loop_weight = initial_self_loop_weight + self.bypass_weight_decay = bypass_weight_decay + self.self_loop_weight_decay = self_loop_weight_decay + + def get_max_token_id(self): + return self.max_token_id + + def make_arc( + self, + from_state: int, + to_state: int, + symbol: Union[str, int], + weight: float, + ): + return f"{from_state} {to_state} {symbol} {weight}" + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of word IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + Returns: + Return a list-of-list of word IDs. + """ + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(): + if word in self.word_table: + word_ids.append(self.word_table[word]) + else: + word_ids.append(self.oov_id) + word_ids_list.append(word_ids) + return word_ids_list + + def compile( + self, + texts: List[str], + allow_bypass_arc: str2bool = True, + allow_self_loop_arc: str2bool = True, + bypass_weight: float = 0.0, + self_loop_weight: float = 0.0, + ) -> k2.Fsa: + """Build a OTC graph from a texts (list of words). + + Args: + texts: + A list of strings. Each string contains a sentence for an utterance. + A sentence consists of spaces separated words. An example `texts` + looks like: + ['hello icefall', 'CTC training with k2'] + allow_bypass_arc: + Whether to add bypass arc to training graph for substitution + and insertion errors (wrong or extra words in the transcript). 
+ allow_self_loop_arc: + Whether to add self-loop arc to training graph for deletion + errors (missing words in the transcript). + bypass_weight: + Weight associated with bypass arc. + self_loop_weight: + Weight associated with self-loop arc. + + Return: + Return an FsaVec, which is the result of composing a + CTC topology with OTC FSAs constructed from the given texts. + """ + + transcript_fsa = self.convert_transcript_to_fsa( + texts, + allow_bypass_arc, + allow_self_loop_arc, + bypass_weight, + self_loop_weight, + ) + fsa_with_self_loop = k2.remove_epsilon_and_add_self_loops(transcript_fsa) + fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop) + + graph = k2.compose( + self.ctc_topo, + fsa_with_self_loop, + treat_epsilons_specially=False, + ) + assert graph.requires_grad is False + + return graph + + def convert_transcript_to_fsa( + self, + texts: List[str], + allow_bypass_arc: str2bool = True, + allow_self_loop_arc: str2bool = True, + bypass_weight: float = 0.0, + self_loop_weight: float = 0.0, + ): + + word_fsa_list = [] + for text in texts: + word_ids = [] + + for word in text.split(): + if word in self.word_table: + word_ids.append(self.word_table[word]) + else: + word_ids.append(self.oov_id) + + arcs = [] + start_state = 0 + cur_state = start_state + next_state = 1 + + for word_id in word_ids: + if allow_self_loop_arc: + self_loop_arc = self.make_arc( + cur_state, + cur_state, + self.otc_id, + self_loop_weight, + ) + arcs.append(self_loop_arc) + + arc = self.make_arc(cur_state, next_state, word_id, 0.0) + arcs.append(arc) + + if allow_bypass_arc: + bypass_arc = self.make_arc( + cur_state, + next_state, + self.otc_id, + bypass_weight, + ) + arcs.append(bypass_arc) + + cur_state = next_state + next_state += 1 + + if allow_self_loop_arc: + self_loop_arc = self.make_arc( + cur_state, + cur_state, + self.otc_id, + self_loop_weight, + ) + arcs.append(self_loop_arc) + + # Deal with final state + final_state = next_state + final_arc = self.make_arc(cur_state, final_state, -1, 0.0) + arcs.append(final_arc) + arcs.append(f"{final_state}") + sorted_arcs = sorted(arcs, key=lambda a: int(a.split()[0])) + + word_fsa = k2.Fsa.from_str("\n".join(sorted_arcs)) + word_fsa = k2.arc_sort(word_fsa) + word_fsa_list.append(word_fsa) + + word_fsa_vec = k2.create_fsa_vec(word_fsa_list).to(self.device) + word_fsa_vec_with_self_loop = k2.add_epsilon_self_loops(word_fsa_vec) + + fsa = k2.intersect( + self.L_inv, word_fsa_vec_with_self_loop, treat_epsilons_specially=False + ) + ans_fsa = fsa.invert_() + return k2.arc_sort(ans_fsa) From b49351fc39dd8def4bb53f9c1c9efa482a0ad769 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sun, 28 Apr 2024 09:56:13 +0800 Subject: [PATCH 159/216] Update README.md for conformer-ctc (#1609) --- egs/aishell/ASR/conformer_ctc/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/conformer_ctc/README.md b/egs/aishell/ASR/conformer_ctc/README.md index 50596ee92..41637159d 100644 --- a/egs/aishell/ASR/conformer_ctc/README.md +++ b/egs/aishell/ASR/conformer_ctc/README.md @@ -1,4 +1,4 @@ Please visit - + for how to run this recipe. 
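A note on the `OtcPhoneTrainingGraphCompiler` added in `icefall/otc_phone_graph_compiler.py` above: the sketch below is not part of the patch, it only illustrates how `convert_transcript_to_fsa` assembles the word-level FSA for one transcript, with a self-loop arc on the OTC token for deletion errors (missing words in the transcript) and a bypass arc for substitution and insertion errors (wrong or extra words in the transcript). It assumes `k2` is installed; the word ids and the OTC id are made up for illustration. In the recipe this word FSA is further intersected with the inverted lexicon and composed with the CTC topology.

```python
# Minimal sketch of the word-level OTC FSA built by
# OtcPhoneTrainingGraphCompiler.convert_transcript_to_fsa.
# The ids below are hypothetical; in the recipe they come from the lexicon.
import k2

otc_id = 2            # id of the special OTC token (hypothetical)
word_ids = [10, 11]   # ids of the transcript words (hypothetical)

arcs = []
cur_state, next_state = 0, 1
for word_id in word_ids:
    # Self-loop arc on the OTC token: covers deletion errors
    # (words missing from the transcript).
    arcs.append(f"{cur_state} {cur_state} {otc_id} 0.0")
    # Arc that consumes the transcript word itself.
    arcs.append(f"{cur_state} {next_state} {word_id} 0.0")
    # Bypass arc on the OTC token: covers substitution and insertion errors
    # (wrong or extra words in the transcript).
    arcs.append(f"{cur_state} {next_state} {otc_id} 0.0")
    cur_state, next_state = next_state, next_state + 1

# Final self-loop, then the arc entering the final state, which uses label -1 in k2,
# and finally the line that declares the final state.
arcs.append(f"{cur_state} {cur_state} {otc_id} 0.0")
arcs.append(f"{cur_state} {next_state} -1 0.0")
arcs.append(f"{next_state}")

word_fsa = k2.arc_sort(k2.Fsa.from_str("\n".join(arcs)))
print(word_fsa)
```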
From 6d7c1d13a5b84f15a0ebd33a590f38a90c2bdb13 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Tue, 30 Apr 2024 11:49:20 +0800 Subject: [PATCH 160/216] update speechio whisper ft results (#1605) * update speechio whisper ft results --- egs/multi_zh-hans/ASR/RESULTS.md | 43 + egs/multi_zh-hans/ASR/prepare.sh | 4 +- egs/multi_zh-hans/ASR/whisper/decode.py | 50 +- .../ASR/whisper/multi_dataset.py | 91 +- egs/multi_zh-hans/ASR/whisper/train.py | 50 +- .../whisper_decoder_forward_monkey_patch.py | 46 + egs/speechio/ASR/RESULTS.md | 114 +- ...pformer_fusion.py => normalize_results.py} | 102 +- egs/speechio/ASR/local/speechio_norm.py | 1364 +++++++++++++++++ egs/wenetspeech/ASR/local/fix_manifest.py | 126 ++ egs/wenetspeech/ASR/prepare.sh | 9 + .../asr_datamodule.py | 4 +- egs/wenetspeech/ASR/whisper/train.py | 8 +- egs/wenetspeech/KWS/run.sh | 8 +- 14 files changed, 1812 insertions(+), 207 deletions(-) mode change 100644 => 100755 egs/multi_zh-hans/ASR/whisper/decode.py mode change 100644 => 100755 egs/multi_zh-hans/ASR/whisper/train.py create mode 100644 egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py rename egs/speechio/ASR/local/{whisper_zipformer_fusion.py => normalize_results.py} (62%) mode change 100644 => 100755 create mode 100755 egs/speechio/ASR/local/speechio_norm.py create mode 100644 egs/wenetspeech/ASR/local/fix_manifest.py diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index 15e789604..a7f3bc4f7 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -1,5 +1,48 @@ ## Results +### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2 +#### Whisper +[./whisper](./whisper) + +Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search. 
+ +| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | +|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------| +| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting | +| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 | + +Command for training is: +```bash +pip install -r whisper/requirements.txt + +# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54 + +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --deepspeed \ + --deepspeed_config ./whisper/ds_config_zero1.json +``` + +Command for decoding using fine-tuned models: +```bash +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper +ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --beam-size 10 --max-duration 50 +``` + +Fine-tuned models, training logs, decoding logs, tensorboard and decoding results +are available at + + + ### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall. diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index d1c7e695c..3d2a9471c 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -226,8 +226,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then log "Stage 11: Prepare WenetSpeech" if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then cd data/fbank - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . 
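Since `prepare.sh` above now links the `cuts_DEV_fixed.jsonl.gz` and `cuts_L_fixed.jsonl.gz` manifests produced by `egs/wenetspeech/ASR/local/fix_manifest.py`, a quick way to spot-check the updated transcripts is to load a few cuts lazily with lhotse. This is only an illustrative snippet; the path is the one the recipe assumes, so adjust it to your layout.

```python
# Illustrative only: inspect a few cuts from the fixed WenetSpeech manifest
# that prepare.sh symlinks into data/fbank.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/cuts_L_fixed.jsonl.gz")
for i, cut in enumerate(cuts):
    # Print the cut id and its (updated) transcript.
    print(cut.id, cut.supervisions[0].text)
    if i >= 4:
        break
```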
diff --git a/egs/multi_zh-hans/ASR/whisper/decode.py b/egs/multi_zh-hans/ASR/whisper/decode.py old mode 100644 new mode 100755 index 2a9c2e75d..f758f546c --- a/egs/multi_zh-hans/ASR/whisper/decode.py +++ b/egs/multi_zh-hans/ASR/whisper/decode.py @@ -57,6 +57,7 @@ from lhotse.cut import Cut from multi_dataset import MultiDataset from tn.chinese.normalizer import Normalizer from whisper.normalizers import BasicTextNormalizer +from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from zhconv import convert @@ -214,7 +215,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], + choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], help="""The model name to use. """, ) @@ -226,6 +227,13 @@ def get_parser(): help="replace whisper encoder forward method to remove input length restriction", ) + parser.add_argument( + "--use-distill-whisper", + type=str2bool, + default=False, + help="Whether to use architecture of distill whisper.", + ) + return parser @@ -307,6 +315,43 @@ def decode_dataset( Returns: Return a dict, whose key may be "beam-search". """ + + def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: + """ + Text normalization similar to M2MeT challenge baseline. + See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + if normalize == "none": + return text + elif normalize == "m2met": + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + results = [] num_cuts = 0 @@ -331,6 +376,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_text = normalize_text_alimeeting(ref_text) ref_words = ref_text.split() this_batch.append((cut_id, ref_words, hyp_words)) @@ -430,6 +476,8 @@ def main(): if params.remove_whisper_encoder_input_length_restriction: replace_whisper_encoder_forward() + if params.use_distill_whisper: + replace_whisper_decoder_forward() model = whisper.load_model(params.model_name, "cpu") if params.epoch > 0: if params.avg > 1: diff --git a/egs/multi_zh-hans/ASR/whisper/multi_dataset.py b/egs/multi_zh-hans/ASR/whisper/multi_dataset.py index b562e626b..d0054c4f7 100644 --- a/egs/multi_zh-hans/ASR/whisper/multi_dataset.py +++ b/egs/multi_zh-hans/ASR/whisper/multi_dataset.py @@ -43,7 +43,7 @@ class MultiDataset: - thchs_30_cuts_train.jsonl.gz - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz - - wenetspeech/cuts_L.jsonl.gz + - wenetspeech/cuts_L_fixed.jsonl.gz """ self.fbank_dir = Path(fbank_dir) @@ -105,7 +105,7 @@ class MultiDataset: # WeNetSpeech logging.info("Loading 
WeNetSpeech in lazy mode") wenetspeech_L_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz" + self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz" ) # KeSpeech @@ -124,10 +124,10 @@ class MultiDataset: aishell_4_L_cuts, aishell_4_M_cuts, aishell_4_S_cuts, + alimeeting_cuts, stcmds_cuts, primewords_cuts, magicdata_cuts, - alimeeting_cuts, wenetspeech_L_cuts, kespeech_1_cuts, kespeech_2_cuts, @@ -138,10 +138,10 @@ class MultiDataset: len(aishell_4_L_cuts), len(aishell_4_M_cuts), len(aishell_4_S_cuts), + len(alimeeting_cuts), len(stcmds_cuts), len(primewords_cuts), len(magicdata_cuts), - len(alimeeting_cuts), len(wenetspeech_L_cuts), len(kespeech_1_cuts), len(kespeech_2_cuts), @@ -151,55 +151,13 @@ class MultiDataset: def dev_cuts(self) -> CutSet: logging.info("About to get multidataset dev cuts") - # AISHELL - logging.info("Loading Aishell DEV set in lazy mode") - aishell_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell_cuts_dev.jsonl.gz" - ) - - # AISHELL-2 - logging.info("Loading Aishell-2 DEV set in lazy mode") - aishell2_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" - ) - - # Ali-Meeting - logging.info("Loading Ali-Meeting DEV set in lazy mode") - alimeeting_dev_cuts = load_manifest_lazy( - self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" - ) - - # MagicData - logging.info("Loading MagicData DEV set in lazy mode") - magicdata_dev_cuts = load_manifest_lazy( - self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" - ) - - # KeSpeech - logging.info("Loading KeSpeech DEV set in lazy mode") - kespeech_dev_phase1_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" - ) - kespeech_dev_phase2_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" - ) - # WeNetSpeech logging.info("Loading WeNetSpeech DEV set in lazy mode") wenetspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" + self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz" ) return wenetspeech_dev_cuts - # return [ - # aishell_dev_cuts, - # aishell2_dev_cuts, - # alimeeting_dev_cuts, - # magicdata_dev_cuts, - # kespeech_dev_phase1_cuts, - # kespeech_dev_phase2_cuts, - # wenetspeech_dev_cuts, - # ] def test_cuts(self) -> Dict[str, CutSet]: logging.info("About to get multidataset test cuts") @@ -267,30 +225,23 @@ class MultiDataset: self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz" ) wenetspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" + self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz" ) return { - "aishell-2_test": aishell2_test_cuts, - "aishell-4": aishell4_test_cuts, - "magicdata_test": magicdata_test_cuts, - "kespeech-asr_test": kespeech_test_cuts, + "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, + # "aishell_test": aishell_test_cuts, + # "aishell_dev": aishell_dev_cuts, + # "ali-meeting_test": alimeeting_test_cuts, + # "ali-meeting_eval": alimeeting_eval_cuts, + # "aishell-4_test": aishell4_test_cuts, + # "aishell-2_test": aishell2_test_cuts, + # "aishell-2_dev": aishell2_dev_cuts, + # "magicdata_test": magicdata_test_cuts, + # "magicdata_dev": magicdata_dev_cuts, + # "kespeech-asr_test": kespeech_test_cuts, + # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts, + # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts, + # "wenetspeech-net_test": wenetspeech_test_net_cuts, + # "wenetspeech_dev": wenetspeech_dev_cuts, } - - # return { - # 
"alimeeting_test": alimeeting_test_cuts, - # "alimeeting_eval": alimeeting_eval_cuts, - # "aishell_test": aishell_test_cuts, - # "aishell_dev": aishell_dev_cuts, - # "aishell-2_test": aishell2_test_cuts, - # "aishell-2_dev": aishell2_dev_cuts, - # "aishell-4": aishell4_test_cuts, - # "magicdata_test": magicdata_test_cuts, - # "magicdata_dev": magicdata_dev_cuts, - # "kespeech-asr_test": kespeech_test_cuts, - # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts, - # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts, - # "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, - # "wenetspeech-net_test": wenetspeech_test_net_cuts, - # "wenetspeech_dev": wenetspeech_dev_cuts, - # } diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py old mode 100644 new mode 100755 index 7a0781d5a..fe2d950c1 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -65,6 +65,7 @@ from torch.cuda.amp import GradScaler from torch.nn.functional import pad as pad_tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from icefall import diagnostics @@ -146,7 +147,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], + choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], help="""The model name to use. """, ) @@ -232,6 +233,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-distill-whisper", + type=str2bool, + default=False, + help="Whether to use architecture of distill whisper.", + ) + parser = deepspeed.add_config_arguments(parser) return parser @@ -441,6 +449,42 @@ def compute_loss( padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) return torch.stack([tensor for tensor in padded_tensors], dim=0) + def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: + """ + Text normalization similar to M2MeT challenge baseline. 
+ See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + if normalize == "none": + return text + elif normalize == "m2met": + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + max_frames = params.max_duration * 1000 // params.frame_shift_ms allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) batch = filter_uneven_sized_batch(batch, allowed_max_frames) @@ -459,7 +503,7 @@ def compute_loss( texts = batch["supervisions"]["text"] # remove spaces in texts - texts = [text.replace(" ", "") for text in texts] + texts = [normalize_text_alimeeting(text) for text in texts] text_tokens_list = [ list(tokenizer.sot_sequence_including_notimestamps) @@ -759,6 +803,8 @@ def run(rank, world_size, args): logging.info("About to create model") replace_whisper_encoder_forward() + if params.use_distill_whisper: + replace_whisper_decoder_forward() model = whisper.load_model(params.model_name, "cpu") del model.alignment_heads diff --git a/egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py b/egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py new file mode 100644 index 000000000..c013426d4 --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py @@ -0,0 +1,46 @@ +from typing import Dict, Iterable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +import whisper +from torch import Tensor, nn +from whisper.model import LayerNorm, ResidualAttentionBlock + + +def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = ( + self.token_embedding(x) + + self.positional_embedding[offset : offset + x.shape[-1]] + ) + x = x + self.positional_embedding[offset : offset + x.shape[1]] + x = x.to(xa.dtype) + + # for block in self.blocks: + # x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + # use architecture from the distill whisper model + # see https://github.com/huggingface/distil-whisper + x = self.blocks[0](x, xa, mask=self.mask, kv_cache=kv_cache) + x = self.blocks[-1](x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return logits + + +def replace_whisper_decoder_forward(): + """ + This function monkey patches the forward method of the whisper encoder. + To be called before the model is loaded, it changes whisper to process audio with any length < 30s. 
+ """ + whisper.model.TextDecoder.forward = forward diff --git a/egs/speechio/ASR/RESULTS.md b/egs/speechio/ASR/RESULTS.md index 07649e383..f1273d41e 100644 --- a/egs/speechio/ASR/RESULTS.md +++ b/egs/speechio/ASR/RESULTS.md @@ -2,50 +2,81 @@ ### SpeechIO Test Set Decoding Results -##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whipser-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper). - -| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion | -|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------| -| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 | -| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 | -| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 | -| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 | -| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 | -| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 | -| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 | -| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 | -| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 | -| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 | -| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 | -| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 | -| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 | -| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 | -| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 | -| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 | -| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 | -| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 | -| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 | -| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 | -| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 | -| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 | -| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 | -| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 | -| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 | -| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 | -| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 | -| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 | +#### **Unlocked** SpeechIO test sets (ZH00001 ~ ZH00026) +| Rank 排名 | Model 模型 | CER 字错误率 | Date 时间 | +| --- | --- | --- | --- | +| 1 | ximalaya_api_zh | 1.72% | 2023.12 | +| 2 | aliyun_ftasr_api_zh | 1.85% | 2023.12 | +| 3 | microsoft_batch_zh | 2.40% | 2023.12 | +| 4 | bilibili_api_zh | 2.90% | 2023.09 | +| 5 | tencent_api_zh | 3.18% | 2023.12 | +| 6 | iflytek_lfasr_api_zh | 3.32% | 2023.12 | +| 7 | aispeech_api_zh | 3.62% | 2023.12 | +| 8 | **whisper-large-ft-v1** | **4.32%** | 2024.04 | +| 9 | **whisper-large-ft-v0.5** | **4.60%** | 2024.04 | +| 10 | **zipformer (70Mb)** | **6.17%** | 2023.10 | +| 11 | **whisper-large-ft-v0** | **6.34%** | 2023.03 | +| 12 | baidu_pro_api_zh | 7.29% | 2023.12 | + +Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). 
All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67) + +
Detail all models

+ +| Model | Training Set | Note | +|----------------------------------------------------------------------------------------------------------|---------------|-----------------------------------------------------| +|[zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24)| multi-hans-zh | decoding with transducer head and blank penalty 2.0 | +|[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/tree/main/exp_large_v2)| wenetspeech | greedy_search, 3 epochs| +|[whisper-large-ft-v0.5](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/blob/main/epoch-2-avg-5.pt)| wenetspeech(updated) | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy_search, 2 epochs | +|[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 3 epochs| + +

+ + +
Detail all results (字错误率 CER %)

+ +| Test Set ID | 测试场景&内容领域 | bilibili_api_zh (2023.09) | whisper-large-ft-v0 | whisper-large-ft-v1 | zipformer | +|----------------------|-------------------------------|-----------------|---------|-----------|-----------| +| Avg (01-26) | | 2.9 | 6.34 | 4.32 | 6.17 | +| SPEECHIO_ASR_ZH00001 | 新闻联播 | 0.54 | 1.42 | 1.09 | 1.37 | +| SPEECHIO_ASR_ZH00002 | 访谈 鲁豫有约 | 2.78 | 4.76 | 3.21 | 4.67 | +| SPEECHIO_ASR_ZH00003 | 电视节目 天下足球 | 0.81 | 2.17 | 1.70 | 2.71 | +| SPEECHIO_ASR_ZH00004 | 场馆演讲 罗振宇跨年 | 1.48 | 2.53 | 1.86 | 2.54 | +| SPEECHIO_ASR_ZH00005 | 在线教育 李永乐 科普 | 1.47 | 4.27 | 1.95 | 3.12 | +| SPEECHIO_ASR_ZH00006 | 直播 王者荣耀 张大仙&骚白 | 5.85 | 12.55 | 9.46 | 12.86 | +| SPEECHIO_ASR_ZH00007 | 直播 带货 李佳琪&薇娅 | 6.19 | 13.38 | 10.38 | 14.58 | +| SPEECHIO_ASR_ZH00008 | 线下培训 老罗语录 | 3.68 | 9.56 | 6.9 | 9.05 | +| SPEECHIO_ASR_ZH00009 | 播客 故事FM | 3.18 | 5.66 | 3.78 | 5.4 | +| SPEECHIO_ASR_ZH00010 | 播客 创业内幕 | 3.51 | 7.84 | 4.36 | 6.4 | +| SPEECHIO_ASR_ZH00011 | 在线教育 罗翔 刑法法考 | 1.77 | 3.22 | 2.40 | 3.12 | +| SPEECHIO_ASR_ZH00012 | 在线教育 张雪峰 考研 | 2.11 | 5.98 | 3.03 | 4.41 | +| SPEECHIO_ASR_ZH00013 | 短视频 影剪 谷阿莫&牛叔说电影 | 2.97 | 5.91 | 3.72 | 6.56 | +| SPEECHIO_ASR_ZH00014 | 短视频 美式&烹饪 | 3.56 | 6.03 | 4.92 | 8.14 | +| SPEECHIO_ASR_ZH00015 | 评书 单田芳 白眉大侠 | 4.72 | 8.77 | 7.92 | 9.1 | +| SPEECHIO_ASR_ZH00016 | 相声 德云社专场 | 3.01 | 5.24 | 4.15 | 5.59 | +| SPEECHIO_ASR_ZH00017 | 脱口秀 吐槽大会 | 2.93 | 7.05 | 3.04 | 5.17 | +| SPEECHIO_ASR_ZH00018 | 少儿卡通 小猪佩奇&熊出没 | 1.98 | 3.53 | 3.27 | 4.15 | +| SPEECHIO_ASR_ZH00019 | 体育赛事解说 NBA比赛 | 2.32 | 6.89 | 4.39 | 6.66 | +| SPEECHIO_ASR_ZH00020 | 纪录片 篮球人物 | 1.51 | 4.16 | 3.04 | 4.2 | +| SPEECHIO_ASR_ZH00021 | 短视频 汽车之家 汽车评测 | 1.75 | 4.77 | 2.69 | 4.17 | +| SPEECHIO_ASR_ZH00022 | 短视频 小艾大叔 豪宅带看 | 3.29 | 6.35 | 5.44 | 6.72 | +| SPEECHIO_ASR_ZH00023 | 短视频 开箱视频 Zeal&无聊开箱 | 2.18 | 8.99 | 4.08 | 7.94 | +| SPEECHIO_ASR_ZH00024 | 短视频 付老师 农业种植 | 4.80 | 10.81 | 6.06 | 8.64 | +| SPEECHIO_ASR_ZH00025 | 线下课堂 石国鹏 古希腊哲学 | 3.32 | 8.41 | 5.39 | 8.54 | +| SPEECHIO_ASR_ZH00026 | 广播电台节目 张震鬼故事 | 3.70 | 4.52 | 4.06 | 4.67 | +

+ Command for decoding using fine-tuned whisper: ```bash git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper -ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt +git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper +ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2_wenetspeech \ + --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ --epoch 999 --avg 1 \ --start-index 0 --end-index 26 \ @@ -76,17 +107,6 @@ mv words.txt ./data/lang_bpe_2000/ --manifest-dir data/fbank_kaldi \ --decoding-method greedy_search ``` -Command for fusion the above decoding results from whisper and zipformer: -```bash -python local/whisper_zipformer_fusion.py \ - --whisper-log-dir ./whisper/exp_large_v2_wenetspeech \ - --zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \ - --output-log-dir ./results_fusion - -``` - -See why the fusion helps [here](./local/whisper_zipformer_fusion.py). SpeechIO fbank features, decoding scripts, logs, and decoding results -are available at - +are available at [part1]() and [part2](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1). diff --git a/egs/speechio/ASR/local/whisper_zipformer_fusion.py b/egs/speechio/ASR/local/normalize_results.py old mode 100644 new mode 100755 similarity index 62% rename from egs/speechio/ASR/local/whisper_zipformer_fusion.py rename to egs/speechio/ASR/local/normalize_results.py index 04c5e75f0..14eb1bb2f --- a/egs/speechio/ASR/local/whisper_zipformer_fusion.py +++ b/egs/speechio/ASR/local/normalize_results.py @@ -21,7 +21,7 @@ Since whisper model is more likely to make deletion errors and zipformer model i we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors. Usage: - python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion + python whisper_zipformer_fusion.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm """ import argparse @@ -29,6 +29,7 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import kaldialign +from speechio_norm import TextNorm from icefall.utils import store_transcripts, write_error_stats @@ -38,31 +39,36 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( - "--whisper-log-dir", + "--model-log-dir", type=str, default="./recogs_whisper", help="The directory to store the whisper logs: e.g. 
recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt", ) - parser.add_argument( - "--zipformer-log-dir", - type=str, - default="./recogs_zipformer", - help="The directory to store the zipformer logs", - ) parser.add_argument( "--output-log-dir", type=str, - default="./results_fusion", - help="The directory to store the fusion logs", + default="./results_whisper_norm", + help="The directory to store the normalized whisper logs", ) return parser -def save_results( +def save_results_with_speechio_text_norm( res_dir: Path, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): + normalizer = TextNorm() + # normlize items in results_dict + for key, results in results_dict.items(): + results_norm = [] + for item in results: + wav_name, ref, hyp = item + ref = normalizer(ref) + hyp = normalizer(hyp) + results_norm.append((wav_name, ref, hyp)) + results_dict[key] = results_norm + test_set_wers = dict() suffix = "epoch-999-avg-1" @@ -120,11 +126,9 @@ def extract_hyp_ref_wavname(filename): return hyps, refs, wav_name -def get_pair_filenames( +def get_filenames( whisper_log_dir, - zipformer_log_dir, whisper_suffix="beam-search-epoch-999-avg-1", - zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0", ): results = [] start_index, end_index = 0, 26 @@ -134,80 +138,24 @@ def get_pair_filenames( dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") for partition in dataset_parts: whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt" - zipformer_filename = ( - f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt" - ) - results.append((whisper_filename, zipformer_filename)) + results.append(whisper_filename) return results -def fusion_hyps_trust_substituion_insertion( - hyps_whisper, hyps_zipformer, refs, ERR="*" -): - """ - alignment example: - [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')] - left is whisper, right is zipformer - for whisper substitution, use left - for whisper insertion, use left - for whisper deletion, use right - """ - hyps_fusion = [] - for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs): - ali = kaldialign.align(hyp_w, hyp_z, ERR) - hyp_f = "" - for a in ali: - if a[0] == ERR: - hyp_f += a[1] - else: - hyp_f += a[0] - hyps_fusion.append(hyp_f) - return hyps_fusion - - -def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"): - """ - alignment example: - [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')] - left is whisper, right is zipformer - for whisper substitution, use left - for whisper insertion, use right - for whisper deletion, use right - """ - hyps_fusion = [] - for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs): - ali = kaldialign.align(hyp_w, hyp_z, ERR) - hyp_f = "" - for a in ali: - if a[0] == ERR: - hyp_f += a[1] - elif a[1] == ERR: - pass - else: - hyp_f += a[0] - hyps_fusion.append(hyp_f) - return hyps_fusion - - def main(): parser = get_parser() args = parser.parse_args() # mkdir output_log_dir Path(args.output_log_dir).mkdir(parents=True, exist_ok=True) - pair_logs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir) - for pair in pair_logs: - hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0]) - hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1]) + filenames = get_filenames(args.model_log_dir) + for filename in filenames: + hyps, refs, wav_name = extract_hyp_ref_wavname(filename) + 
partition_name = filename.split("/")[-1].split("-")[1] - hyps_fusion = fusion_hyps_trust_substituion_insertion( - hyps_whisper, hyps_zipformer, refs - ) - - partition_name = pair[0].split("/")[-1].split("-")[1] - save_results( + save_results_with_speechio_text_norm( Path(args.output_log_dir), partition_name, - {"fusion": list(zip(wav_name, refs, hyps_fusion))}, + {"norm": list(zip(wav_name, refs, hyps))}, ) print(f"Processed {partition_name}") diff --git a/egs/speechio/ASR/local/speechio_norm.py b/egs/speechio/ASR/local/speechio_norm.py new file mode 100755 index 000000000..6f3cd55b0 --- /dev/null +++ b/egs/speechio/ASR/local/speechio_norm.py @@ -0,0 +1,1364 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 - 2022 Jiayu DU +# +# requirements: +# - python 3.X +# notes: python 2.X WILL fail or produce misleading results + +import argparse +import csv +import os +import re +import string +import sys + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = "零一二三四五六七八九" +BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" +BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" + +ZERO_ALT = "〇" +ONE_ALT = "幺" +TWO_ALTS = ["两", "兩"] + +POSITIVE = ["正", "正"] +NEGATIVE = ["负", "負"] +POINT = ["点", "點"] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +FILLER_CHARS = ["呃", "啊"] + +ER_WHITELIST = ( + "(儿女|儿子|儿孙|女儿|儿媳|妻儿|" + "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|" + "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|" + "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)" +) +ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST) + +# 中文数字系统类型 +NUMBERING_TYPES = ["low", "mid", "high"] + +CURRENCY_NAMES = ( + "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" + "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" +) +CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" +COM_QUANTIFIERS = ( + "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" + "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" + "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" + "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" + "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)" +) + + +# Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) +CN_PUNCS_STOP = "!?。。" +CN_PUNCS_NONSTOP = ( + ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-" +) +CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP + +PUNCS = CN_PUNCS + string.punctuation +PUNCS_TRANSFORM = str.maketrans(PUNCS, " " * len(PUNCS), "") # replace puncs with space + + +# https://zh.wikipedia.org/wiki/全行和半行 +QJ2BJ = { + " ": " ", + "!": "!", + """: '"', + "#": "#", + "$": "$", + "%": "%", + "&": "&", + "'": "'", + "(": "(", + ")": ")", + "*": "*", + "+": "+", + ",": ",", + "-": "-", + ".": ".", + "/": "/", + "0": "0", + "1": "1", + "2": "2", + "3": "3", + "4": "4", + "5": "5", + "6": "6", + "7": "7", + "8": "8", + "9": "9", + ":": ":", + ";": ";", + "<": "<", + "=": "=", + ">": 
">", + "?": "?", + "@": "@", + "A": "A", + "B": "B", + "C": "C", + "D": "D", + "E": "E", + "F": "F", + "G": "G", + "H": "H", + "I": "I", + "J": "J", + "K": "K", + "L": "L", + "M": "M", + "N": "N", + "O": "O", + "P": "P", + "Q": "Q", + "R": "R", + "S": "S", + "T": "T", + "U": "U", + "V": "V", + "W": "W", + "X": "X", + "Y": "Y", + "Z": "Z", + "[": "[", + "\": "\\", + "]": "]", + "^": "^", + "_": "_", + "`": "`", + "a": "a", + "b": "b", + "c": "c", + "d": "d", + "e": "e", + "f": "f", + "g": "g", + "h": "h", + "i": "i", + "j": "j", + "k": "k", + "l": "l", + "m": "m", + "n": "n", + "o": "o", + "p": "p", + "q": "q", + "r": "r", + "s": "s", + "t": "t", + "u": "u", + "v": "v", + "w": "w", + "x": "x", + "y": "y", + "z": "z", + "{": "{", + "|": "|", + "}": "}", + "~": "~", +} +QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "") + + +# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources: +# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total +CN_CHARS_COMMON = ( + "一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举" + "乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互" + "亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从" + "仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优" + "伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚" + "佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣" + "侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯" + "俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌" + "偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚" + "僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六" + "兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况" + "冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈" + "刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐" + "剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼" + "劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹" + "区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵" + "卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔" + "叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊" + "同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆" + "呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎" + "咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌" + "响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛" + "唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴" + "啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌" + "嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡" + "嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓" + "嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢" + "圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥" + "坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩" + "垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基" + "埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填" + "塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复" + "夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖" + "套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮" + "妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈" + "娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻" + "婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱" + "嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽" + "宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾" + "宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝" + "尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山" + "屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃" + "峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧" + "崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉" + "巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡" + "带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐" + "庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷" + "建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖" + "彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循" + "徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀" + "态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓" + "恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟" + "悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭" + "惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥" + 
"慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我" + "戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔" + "托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡" + "抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥" + "拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫" + "振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎" + "掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭" + "揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘" + "摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢" + "擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整" + "敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗" + "旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝" + "星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡" + "晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜" + "曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽" + "杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅" + "枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔" + "柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩" + "株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯" + "桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘" + "棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂" + "楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱" + "榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞" + "橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正" + "此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒" + "毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮" + "氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽" + "汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽" + "沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼" + "泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿" + "流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸" + "浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅" + "淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟" + "渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆" + "溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕" + "滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶" + "漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧" + "澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵" + "灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈" + "烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯" + "焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵" + "熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍" + "牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷" + "犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎" + "猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎" + "玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊" + "珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊" + "琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙" + "瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱" + "璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯" + "田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐" + "疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒" + "痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩" + "瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙" + "皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾" + "省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦" + "睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知" + "矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮" + "砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍" + "碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷" + "磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭" + "祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣" + "秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗" + "穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立" + "竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯" + "笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅" + "箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾" + "簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮" + "粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢" + "縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁" + "绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩" + "绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕" + "编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐" + "网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲" + "羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者" + "耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋" + "职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸" + "肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳" + "胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒" + 
"腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻" + "臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般" + "舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊" + "芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈" + "苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆" + "茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐" + "荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛" + "莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥" + "菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著" + "葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺" + "蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷" + "蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸" + "薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱" + "虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆" + "蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗" + "蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃" + "螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡" + "蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒" + "袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂" + "褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉" + "觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认" + "讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词" + "诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请" + "诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡" + "谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹" + "豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼" + "贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤" + "赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑" + "跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮" + "踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇" + "躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较" + "辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄" + "迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆" + "选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒" + "道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱" + "邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴" + "郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝" + "酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭" + "醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒" + "钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼" + "钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨" + "铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐" + "锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹" + "锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣" + "镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼" + "闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶" + "阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈" + "隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳" + "零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰" + "靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵" + "韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额" + "颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰" + "饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥" + "馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑" + "骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高" + "髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾" + "鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨" + "鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓" + "鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶" + "鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟" + "鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄" + "黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷" + "鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃" + "㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡" + "䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽" + "𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯" + "𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟" + "𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟" + "𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟" + "𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓" +) +CN_CHARS_EXT = "吶诶屌囧飚屄" + +CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT +IN_CH_CHARS = {c: True for c in CN_CHARS} + +EN_CHARS = string.ascii_letters + string.digits +IN_EN_CHARS = {c: True for c in EN_CHARS} + +VALID_CHARS = CN_CHARS + EN_CHARS + " " +IN_VALID_CHARS = {c: True for c in VALID_CHARS} + +# ================================================================================ # +# basic class +# 
================================================================================ # + + +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + # self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. '陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return "10^{}".format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit( + power=index + 1, + simplified=value[0], + traditional=value[1], + big_s=value[1], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + else: + raise ValueError( + "Counting type should be in {0} ({1} provided).".format( + NUMBERING_TYPES, numbering_type + ) + ) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__( + self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None + ): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. 
+ positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + larger_units = [ + CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units) + ] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + smaller_units = [ + CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units) + ] + # digis + chinese_digis = zip( + CHINESE_DIGIS, + CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, + BIG_CHINESE_DIGIS_TRADITIONAL, + ) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) + point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." 
+ str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [ + d.traditional, + d.simplified, + d.big_s, + d.big_t, + d.alt_s, + d.alt_t, + ]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, "" + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], [ + get_symbol(c, system) for c in dec_string + ] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance( + integer_symbols[-2], CNU + ): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None) + ) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if ( + isinstance(result[-i - 1], CNU) + and result[-i - 1].power < current_unit.power + ): + result[-i - 1] = CNU( + result[-i - 1].power + current_unit.power, + None, + None, + None, + None, + ) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = "".join([str(d.value) for d in dec_part]) + if dec_part: + return "{0}.{1}".format(int_str, dec_str) + else: + return int_str + + +def num2chn( + number_string, + numbering_type=NUMBERING_TYPES[1], + big=False, + traditional=False, + alt_zero=False, + alt_one=False, + alt_two=True, + use_zeros=True, + use_units=True, +): + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip("0") + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next( + u for u in reversed(system.units) if u.power < len(striped_string) + ) + result_string = value_string[: -result_unit.power] + return ( + get_value(result_string) + + [result_unit] + + get_value(striped_string[-result_unit.power :]) + ) + + system = create_system(numbering_type) + + int_dec = number_string.split(".") + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string) + ) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND( + 2, + system.digits[2].alt_s, + system.digits[2].alt_t, + system.digits[2].big_s, + system.digits[2].big_t, + ) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = ( + result_symbols[i + 1] if i < len(result_symbols) - 1 else None + ) + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance( + previous_symbol, (CNU, type(None)) + ): + if next_symbol.power != 1 and ( + (previous_symbol is None) or (previous_symbol.power != 1) + ): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = "big_" + if traditional: + attr_name += "t" + else: + attr_name += "s" + else: + if traditional: + attr_name = "traditional" + else: + attr_name = "simplified" + + result = "".join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s + ) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s + ) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + 
result + + # ^10, 11, .., 19 + if ( + len(result) >= 2 + and result[1] + in [ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0], + ] + and result[0] + in [ + CHINESE_DIGIS[1], + BIG_CHINESE_DIGIS_SIMPLIFIED[1], + BIG_CHINESE_DIGIS_TRADITIONAL[1], + ] + ): + result = result[1:] + + return result + + +# ================================================================================ # +# different types of rewriters +# ================================================================================ # +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + + if fixed: + sil_parts = self.telephone.split("-") + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + else: + sp_parts = self.telephone.strip("+").split() + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + return self.chntext + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split("分之") + return chn2num(numerator) + "/" + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split("/") + return num2chn(denominator) + "分之" + num2chn(numerator) + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, other = date.strip().split("年", 1) + year = Digit(digit=year).digit2chntext() + "年" + except ValueError: + other = date + year = "" + if other: + try: + month, day = other.strip().split("月", 1) + month = Cardinal(cardinal=month).cardinal2chntext() + "月" + except ValueError: + day = date + month = "" + if day: + day = 
Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = "" + day = "" + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() + ) + self.chntext = money + return self.chntext + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip("百分之")) + "%" + + def percentage2chntext(self): + return "百分之" + num2chn(self.percentage.strip().strip("%")) + + +def normalize_nsw(raw_text): + text = "^" + raw_text + "$" + + # 规范化日期 + pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") + matchers = pattern.findall(text) + if matchers: + # print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile( + r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('money') + for matcher in matchers: + text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + # print('telephone') + for matcher in matchers: + text = text.replace( + matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1 + ) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace( + matcher[0], + TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), + 1, + ) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + # print('fraction') + for matcher in matchers: + text = text.replace( + matcher, Fraction(fraction=matcher).fraction2chntext(), 1 + ) + + # 规范化百分数 + text = text.replace("%", "%") + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + # print('percentage') + for matcher in matchers: + text = text.replace( + matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1 + ) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" 
+ COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + # print('cardinal+quantifier') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + # print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + # print('cardinal') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + # restore P2P, O2O, B2C, B2B etc + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) + + return text.lstrip("^").rstrip("$") + + +def remove_erhua(text): + """ + 去除儿化音词中的儿: + 他女儿在那边儿 -> 他女儿在那边 + """ + + new_str = "" + while re.search("儿", text): + a = re.search("儿", text).span() + remove_er_flag = 0 + + if ER_WHITELIST_PATTERN.search(text): + b = ER_WHITELIST_PATTERN.search(text).span() + if b[0] <= a[0]: + remove_er_flag = 1 + + if remove_er_flag == 0: + new_str = new_str + text[0 : a[0]] + text = text[a[1] :] + else: + new_str = new_str + text[0 : b[1]] + text = text[b[1] :] + + text = new_str + text + return text + + +def remove_space(text): + tokens = text.split() + new = [] + for k, t in enumerate(tokens): + if k != 0: + if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]): + new.append(" ") + new.append(t) + return "".join(new) + + +class TextNorm: + def __init__( + self, + to_banjiao: bool = True, + to_upper: bool = True, + to_lower: bool = False, + remove_fillers: bool = True, + remove_erhua: bool = True, + check_chars: bool = False, + remove_space: bool = False, + cc_mode: str = "", + ): + self.to_banjiao = to_banjiao + self.to_upper = to_upper + self.to_lower = to_lower + self.remove_fillers = remove_fillers + self.remove_erhua = remove_erhua + self.check_chars = check_chars + self.remove_space = remove_space + + self.cc = None + if cc_mode: + from opencc import OpenCC # Open Chinese Convert: pip install opencc + + self.cc = OpenCC(cc_mode) + + def __call__(self, text): + if self.cc: + text = self.cc.convert(text) + + if self.to_banjiao: + text = text.translate(QJ2BJ_TRANSFORM) + + if self.to_upper: + text = text.upper() + + if self.to_lower: + text = text.lower() + + if self.remove_fillers: + for c in FILLER_CHARS: + text = text.replace(c, "") + + if self.remove_erhua: + text = remove_erhua(text) + + text = normalize_nsw(text) + + text = text.translate(PUNCS_TRANSFORM) + + if self.check_chars: + for c in text: + if not IN_VALID_CHARS.get(c): + print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr) + return "" + + if self.remove_space: + text = remove_space(text) + + return text + + +if __name__ == "__main__": + p = argparse.ArgumentParser() + + # normalizer options + p.add_argument( + "--to_banjiao", action="store_true", help="convert quanjiao chars to banjiao" + ) + p.add_argument("--to_upper", action="store_true", help="convert to upper case") + p.add_argument("--to_lower", action="store_true", help="convert to lower case") + p.add_argument( + "--remove_fillers", + action="store_true", + help='remove filler chars such as "呃, 啊"', + ) + p.add_argument( + "--remove_erhua", + action="store_true", + help='remove 
erhua chars such as "他女儿在那边儿 -> 他女儿在那边"', + ) + p.add_argument( + "--check_chars", + action="store_true", + help="skip sentences containing illegal chars", + ) + p.add_argument("--remove_space", action="store_true", help="remove whitespace") + p.add_argument( + "--cc_mode", + choices=["", "t2s", "s2t"], + default="", + help="convert between traditional to simplified", + ) + + # I/O options + p.add_argument( + "--log_interval", + type=int, + default=10000, + help="log interval in number of processed lines", + ) + p.add_argument( + "--has_key", + action="store_true", + help="will be deprecated, set --format ark instead", + ) + p.add_argument( + "--format", + type=str, + choices=["txt", "ark", "tsv"], + default="txt", + help="input format", + ) + p.add_argument("ifile", help="input filename, assume utf-8 encoding") + p.add_argument("ofile", help="output filename") + + args = p.parse_args() + + if args.has_key: + args.format = "ark" + + normalizer = TextNorm( + to_banjiao=args.to_banjiao, + to_upper=args.to_upper, + to_lower=args.to_lower, + remove_fillers=args.remove_fillers, + remove_erhua=args.remove_erhua, + check_chars=args.check_chars, + remove_space=args.remove_space, + cc_mode=args.cc_mode, + ) + + ndone = 0 + with open(args.ifile, "r", encoding="utf8") as istream, open( + args.ofile, "w+", encoding="utf8" + ) as ostream: + if args.format == "tsv": + reader = csv.DictReader(istream, delimiter="\t") + assert "TEXT" in reader.fieldnames + print("\t".join(reader.fieldnames), file=ostream) + + for item in reader: + text = item["TEXT"] + + if text: + text = normalizer(text) + + if text: + item["TEXT"] = text + print("\t".join([item[f] for f in reader.fieldnames]), file=ostream) + + ndone += 1 + if ndone % args.log_interval == 0: + print( + f"text norm: {ndone} lines done.", file=sys.stderr, flush=True + ) + else: + for line in istream: + key, text = "", "" + if args.format == "ark": # KALDI archive, line format: "key text" + cols = line.strip().split(maxsplit=1) + key, text = cols[0], cols[1] if len(cols) == 2 else "" + else: + text = line.strip() + + if text: + text = normalizer(text) + + if text: + if args.format == "ark": + print(key + "\t" + text, file=ostream) + else: + print(text, file=ostream) + + ndone += 1 + if ndone % args.log_interval == 0: + print( + f"text norm: {ndone} lines done.", file=sys.stderr, flush=True + ) + print(f"text norm: {ndone} lines done in total.", file=sys.stderr, flush=True) diff --git a/egs/wenetspeech/ASR/local/fix_manifest.py b/egs/wenetspeech/ASR/local/fix_manifest.py new file mode 100644 index 000000000..b2632bd52 --- /dev/null +++ b/egs/wenetspeech/ASR/local/fix_manifest.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# Copyright 2024 author: Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import logging + +from lhotse import CutSet, load_manifest_lazy + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--fixed-transcript-path", + type=str, + default="data/fbank/text.fix", + help=""" + See https://github.com/wenet-e2e/WenetSpeech/discussions/54 + wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix + """, + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/fbank/", + help="Directory to store the manifest files", + ) + + parser.add_argument( + "--training-subset", + type=str, + default="L", + help="The training subset for wenetspeech.", + ) + + return parser + + +def load_fixed_text(fixed_text_path): + """ + fixed text format + X0000016287_92761015_S00001 我是徐涛 + X0000016287_92761015_S00002 狄更斯的PICK WEEK PAPERS斯 + load into a dict + """ + fixed_text_dict = {} + with open(fixed_text_path, "r") as f: + for line in f: + cut_id, text = line.strip().split(" ", 1) + fixed_text_dict[cut_id] = text + return fixed_text_dict + + +def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path): + with CutSet.open_writer(fixed_manifest_path) as manifest_writer: + fixed_item = 0 + for i, cut in enumerate(manifest): + if i % 10000 == 0: + logging.info(f"Processing cut {i}, fixed {fixed_item}") + cut_id_orgin = cut.id + if cut_id_orgin.endswith("_sp0.9"): + cut_id = cut_id_orgin[:-6] + elif cut_id_orgin.endswith("_sp1.1"): + cut_id = cut_id_orgin[:-6] + else: + cut_id = cut_id_orgin + if cut_id in fixed_text_dict: + assert ( + len(cut.supervisions) == 1 + ), f"cut {cut_id} has {len(cut.supervisions)} supervisions" + if cut.supervisions[0].text != fixed_text_dict[cut_id]: + logging.info( + f"Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}" + ) + cut.supervisions[0].text = fixed_text_dict[cut_id] + fixed_item += 1 + manifest_writer.write(cut) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + fixed_text_path = args.manifest_dir + "text.fix" + fixed_text_dict = load_fixed_text(fixed_text_path) + logging.info(f"Loaded {len(fixed_text_dict)} fixed texts") + + dev_manifest_path = args.manifest_dir + "cuts_DEV.jsonl.gz" + fixed_dev_manifest_path = args.manifest_dir + "cuts_DEV_fixed.jsonl.gz" + logging.info(f"Loading dev manifest from {dev_manifest_path}") + cuts_dev_manifest = load_manifest_lazy(dev_manifest_path) + fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path) + logging.info(f"Fixed dev manifest saved to {fixed_dev_manifest_path}") + + manifest_path = args.manifest_dir + f"cuts_{args.training_subset}.jsonl.gz" + fixed_manifest_path = ( + args.manifest_dir + f"cuts_{args.training_subset}_fixed.jsonl.gz" + ) + logging.info(f"Loading manifest from {manifest_path}") + cuts_manifest = load_manifest_lazy(manifest_path) + fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path) + logging.info(f"Fixed training manifest saved to {fixed_manifest_path}") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index e3e28bd24..45912985b 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -416,3 +416,12 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; 
then python ./local/compile_lg.py --lang-dir $lang_dir done fi + +if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then + log "Stage 23: Modify transcript according to fixed results" + # See https://github.com/wenet-e2e/WenetSpeech/discussions/54 + wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix -O data/fbank/text.fix + python local/fix_manifest.py \ + --fixed-transcript-path data/fbank/text.fix \ + --training-subset L +fi diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 58da1d68c..8b35187b1 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -390,14 +390,14 @@ class WenetSpeechAsrDataModule: def train_cuts(self) -> CutSet: logging.info("About to get train cuts") cuts_train = load_manifest_lazy( - self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz" + self.args.manifest_dir / f"cuts_{self.args.training_subset}_fixed.jsonl.gz" ) return cuts_train @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV_fixed.jsonl.gz") @lru_cache() def test_net_cuts(self) -> List[CutSet]: diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py index 493f2728a..4e55fd6a8 100644 --- a/egs/wenetspeech/ASR/whisper/train.py +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -38,6 +38,7 @@ torchrun --nproc_per_node 8 ./whisper/train.py \ import argparse import copy import logging +import os import random import warnings from pathlib import Path @@ -145,7 +146,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], + choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], help="""The model name to use. """, ) @@ -616,7 +617,9 @@ def train_one_epoch( f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", ) - + os.system( + f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" + ) try: with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( @@ -893,6 +896,7 @@ def run(rank, world_size, args): f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", tag=f"epoch-{params.cur_epoch}", ) + os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") else: save_checkpoint( params=params, diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 8455cc5be..232ee039a 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -25,7 +25,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "You need to run the prepare.sh first." 
exit -1 fi - + python ./zipformer/train.py \ --world-size 4 \ --exp-dir zipformer/exp \ @@ -105,11 +105,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 \ --causal 1 -fi +fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 2: Finetune the model" - + # The following configuration of lr schedule should work well # You may also tune the following parameters to adjust learning rate schedule base_lr=0.0005 @@ -201,4 +201,4 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 \ --causal 1 -fi +fi From c08fe486038440c08d99562858a951ddbdf0aa7b Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 4 May 2024 11:42:23 +0800 Subject: [PATCH 161/216] add force=True to logging.basicConfig (#1613) --- icefall/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/icefall/utils.py b/icefall/utils.py index ec6aee6d0..57467cb56 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -163,6 +163,7 @@ def setup_logger( format=formatter, level=level, filemode="w", + force=True, ) if use_console: console = logging.StreamHandler() From 4e97b19b6355a4c3afb50f9270f74d96b7988f42 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 6 May 2024 13:00:27 +0800 Subject: [PATCH 162/216] Remove duplicate logging initialization logic in utils.py (#1617) --- icefall/utils.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index 57467cb56..0c509fc17 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -110,13 +110,6 @@ def str2bool(v): raise argparse.ArgumentTypeError("Boolean value expected.") -def clear_log_handlers(): - logger = logging.getLogger() - handlers = logger.handlers[:] - for handler in handlers: - logger.removeHandler(handler) - - def setup_logger( log_filename: Pathlike, log_level: str = "info", @@ -133,8 +126,6 @@ def setup_logger( use_console: True to also print logs to console. 
""" - clear_log_handlers() - now = datetime.now() date_time = now.strftime("%Y-%m-%d-%H-%M-%S") if dist.is_available() and dist.is_initialized(): From 9d570870cff167b8b963ed1f73dae0d685ec53cf Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 7 May 2024 21:37:55 +0800 Subject: [PATCH 163/216] Update asr_datamodule.py (#1619) --- egs/yesno/ASR/tdnn/asr_datamodule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index b9ce8fb4e..99f2a6d08 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -181,7 +181,7 @@ class YesNoAsrDataModule(DataModule): train = K2SpeechRecognitionDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures( - FbankConfig(sampling_rate=8000, num_mel_bins=23) + Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=23)) ), return_cuts=self.args.return_cuts, ) @@ -222,9 +222,11 @@ class YesNoAsrDataModule(DataModule): logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) + if self.args.on_the_fly_feats + else PrecomputedFeatures() + ), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( From 68980c5d0abf41286a59e3596321c57691deab4a Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 17 May 2024 19:45:15 +0800 Subject: [PATCH 164/216] Fix an error occured during mmi preparation (#1626) * init commit * updated --- egs/librispeech/ASR/prepare_lm.sh | 2 ++ egs/librispeech/ASR/prepare_mmi.sh | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/egs/librispeech/ASR/prepare_lm.sh b/egs/librispeech/ASR/prepare_lm.sh index a8eb5ca78..1792395d8 100755 --- a/egs/librispeech/ASR/prepare_lm.sh +++ b/egs/librispeech/ASR/prepare_lm.sh @@ -31,6 +31,8 @@ log "Running prepare_lm.sh" stage=0 stop_stage=100 +. shared/parse_options.sh || exit 1 + if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Prepare BPE based lexicon." diff --git a/egs/librispeech/ASR/prepare_mmi.sh b/egs/librispeech/ASR/prepare_mmi.sh index d8a6e0caf..f877cac22 100755 --- a/egs/librispeech/ASR/prepare_mmi.sh +++ b/egs/librispeech/ASR/prepare_mmi.sh @@ -8,11 +8,15 @@ set -eou pipefail . prepare.sh --stage -1 --stop-stage 6 || exit 1 +. prepare_lm.sh --stage 0 --stop-stage 0 || exit 1 + log "Running prepare_mmi.sh" stage=0 stop_stage=100 +. 
shared/parse_options.sh || exit 1 + if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Prepare bigram token-level P for MMI training" From 0df406c5da5dfeca9721ccfb26f5164c626ee4b2 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 20 May 2024 22:32:02 +0800 Subject: [PATCH 165/216] Initialize BiasNorm bias with small random values (#1630) --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 29ac33c02..fb2bf1b79 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -462,7 +462,7 @@ class BiasNorm(torch.nn.Module): self.num_channels = num_channels self.channel_dim = channel_dim self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.bias = nn.Parameter(torch.zeros(num_channels)) + self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) self.log_scale_min = log_scale_min self.log_scale_max = log_scale_max From 1adf1e441d7ad49d5d4a96246a28aa9e12d6f967 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 21 May 2024 18:22:19 +0800 Subject: [PATCH 166/216] Removed unused ``k2`` dependencies from the AT recipe (#1633) --- egs/audioset/AT/zipformer/at_datamodule.py | 8 ++++--- egs/audioset/AT/zipformer/evaluate.py | 25 ++++---------------- egs/audioset/AT/zipformer/export-onnx.py | 3 +-- egs/audioset/AT/zipformer/jit_pretrained.py | 1 - egs/audioset/AT/zipformer/model.py | 8 ++----- egs/audioset/AT/zipformer/onnx_pretrained.py | 3 +-- egs/audioset/AT/zipformer/train.py | 5 ++-- 7 files changed, 15 insertions(+), 38 deletions(-) diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py index 66497c1ca..ac8671fa6 100644 --- a/egs/audioset/AT/zipformer/at_datamodule.py +++ b/egs/audioset/AT/zipformer/at_datamodule.py @@ -373,9 +373,11 @@ class AudioSetATDatamodule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = AudioTaggingDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)() + ), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( diff --git a/egs/audioset/AT/zipformer/evaluate.py b/egs/audioset/AT/zipformer/evaluate.py index b52a284d0..0a1b8ea5f 100644 --- a/egs/audioset/AT/zipformer/evaluate.py +++ b/egs/audioset/AT/zipformer/evaluate.py @@ -29,27 +29,18 @@ export CUDA_VISIBLE_DEVICES="0" """ import argparse -import csv import logging -import math -import os -from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict -import k2 -import numpy as np -import sentencepiece as spm import torch import torch.nn as nn -import torch.nn.functional as F from at_datamodule import AudioSetATDatamodule -from lhotse import load_manifest try: from sklearn.metrics import average_precision_score -except Exception as ex: - raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn") +except: + raise ImportError(f"Please run\n" "pip3 install -U scikit-learn") from train import add_model_arguments, get_model, get_params, str2multihot from icefall.checkpoint import ( @@ -58,15 +49,7 @@ from icefall.checkpoint import ( find_checkpoints, 
load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) +from icefall.utils import AttributeDict, setup_logger, str2bool def get_parser(): diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py index 24b7717b4..2b0ec8b4b 100755 --- a/egs/audioset/AT/zipformer/export-onnx.py +++ b/egs/audioset/AT/zipformer/export-onnx.py @@ -36,7 +36,6 @@ import logging from pathlib import Path from typing import Dict -import k2 import onnx import onnxoptimizer import torch @@ -53,7 +52,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, num_tokens, str2bool +from icefall.utils import make_pad_mask, str2bool def get_parser(): diff --git a/egs/audioset/AT/zipformer/jit_pretrained.py b/egs/audioset/AT/zipformer/jit_pretrained.py index 403308fcf..d376aa148 100755 --- a/egs/audioset/AT/zipformer/jit_pretrained.py +++ b/egs/audioset/AT/zipformer/jit_pretrained.py @@ -50,7 +50,6 @@ import logging import math from typing import List -import k2 import kaldifeat import torch import torchaudio diff --git a/egs/audioset/AT/zipformer/model.py b/egs/audioset/AT/zipformer/model.py index f189eac62..fb8e2dd85 100644 --- a/egs/audioset/AT/zipformer/model.py +++ b/egs/audioset/AT/zipformer/model.py @@ -14,17 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -import random -from typing import List, Optional, Tuple +from typing import Tuple -import k2 import torch import torch.nn as nn -import torch.nn.functional as F from encoder_interface import EncoderInterface -from icefall.utils import AttributeDict, make_pad_mask +from icefall.utils import make_pad_mask class AudioTaggingModel(nn.Module): diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py index 82fa3d45b..8de60bbb5 100755 --- a/egs/audioset/AT/zipformer/onnx_pretrained.py +++ b/egs/audioset/AT/zipformer/onnx_pretrained.py @@ -42,9 +42,8 @@ import argparse import csv import logging import math -from typing import List, Tuple +from typing import List -import k2 import kaldifeat import onnxruntime as ort import torch diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 0e234c59f..2d193030a 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -41,7 +41,6 @@ from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple, Union import optim -import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn @@ -632,7 +631,7 @@ def compute_loss( model: The model for training. It is an instance of Zipformer in our case. batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + A batch of data. See `lhotse.dataset.AudioTaggingDataset()` for the content in it. is_training: True for training. False for validation. When it is True, this @@ -1108,7 +1107,7 @@ def display_and_save_batch( Args: batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + A batch of data. See `lhotse.dataset.AudioTaggingDataset()` for the content in it. params: Parameters for training. See :func:`get_params`. 
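Illustration (not part of the patch series above): the AT recipe's evaluate.py keeps a guarded import of scikit-learn because evaluation presumably relies on sklearn.metrics.average_precision_score for the audio-tagging metrics. A minimal, self-contained sketch, with toy multi-hot labels and scores invented purely for illustration, of how that function yields a macro-averaged AP (an mAP-style figure) for multi-label data:

    # Toy multi-hot ground truth and model scores (3 cuts, 3 classes); values are invented.
    import numpy as np
    from sklearn.metrics import average_precision_score

    labels = np.array([[1, 0, 1],
                       [0, 1, 0],
                       [1, 1, 0]])
    scores = np.array([[0.9, 0.2, 0.7],
                       [0.1, 0.8, 0.3],
                       [0.6, 0.7, 0.2]])

    # average="macro" averages the per-class average precisions;
    # average=None would return the per-class values instead.
    print(average_precision_score(labels, scores, average="macro"))
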
From 42a97f6d7b6a700f8207f66a7520d15f0d02b994 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 22 May 2024 22:29:38 +0800 Subject: [PATCH 167/216] Update env.py (#1635) --- icefall/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/env.py b/icefall/env.py index 373e9a9ff..6ebc30f6b 100644 --- a/icefall/env.py +++ b/icefall/env.py @@ -108,7 +108,7 @@ def get_env_info() -> Dict[str, Any]: "torch-version": str(torch.__version__), "torch-cuda-available": torch.cuda.is_available(), "torch-cuda-version": torch.version.cuda, - "python-version": sys.version[:3], + "python-version": ".".join(sys.version.split(".")[:2]), "icefall-git-branch": get_git_branch_name(), "icefall-git-sha1": get_git_sha1(), "icefall-git-date": get_git_date(), From b88062292b8208bacf429fcf2d020c08b0333dad Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 3 Jun 2024 16:49:21 +0800 Subject: [PATCH 168/216] Typo fixes (#1643) --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 28 +++++++++++----------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index fb2bf1b79..e7c3f4ab1 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -137,7 +137,7 @@ class PiecewiseLinear(object): p: the other piecewise linear function include_crossings: if true, include in the x values positions - where the functions indicate by this and p crosss. + where the functions indicate by this and p cross. """ assert isinstance(p, PiecewiseLinear), type(p) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 17a3f8719..69059287b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -205,9 +205,9 @@ class Zipformer2(EncoderInterface): """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than + On e.g. 15% of frames, these masks will zero out all encoder dims larger than some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. + a smaller encoder dim. We generate the random masks at this level because we want the 2 masks to 'agree' all the way up the encoder stack. This will mean that the 1st mask will have @@ -548,9 +548,9 @@ class Zipformer2EncoderLayer(nn.Module): Args: embed_dim: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). + feedforward_dim: the dimension of the feedforward network model (required). dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. + cnn_module_kernel (int): Kernel size of convolution module (default=31). 
Examples:: >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) @@ -1028,7 +1028,7 @@ class Zipformer2Encoder(nn.Module): ) self.num_layers = num_layers - assert 0 <= warmup_begin <= warmup_end + assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin # interpreted as a training batch index @@ -1177,7 +1177,7 @@ class BypassModule(nn.Module): def _get_bypass_scale(self, batch_size: int): # returns bypass-scale of shape (num_channels,), # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 correponds to bypassing + # scale on the non-residual term, so 0 corresponds to bypassing # this module. if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return self.bypass_scale @@ -1381,12 +1381,12 @@ class CompactRelPositionalEncoding(torch.nn.Module): when encoding absolute position, but not important when encoding relative position because there is now no need to compare two large offsets with each other. - Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the fourier transform of that fixed interval. The + Our embedding works by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the Fourier transform of that fixed interval. The atan() function would compress the "long tails" too small, making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) @@ -1408,10 +1408,10 @@ class CompactRelPositionalEncoding(torch.nn.Module): """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() self.embed_dim = embed_dim - assert embed_dim % 2 == 0 + assert embed_dim % 2 == 0, embed_dim self.dropout = Dropout2(dropout_rate) self.pe = None - assert length_factor >= 1.0 + assert length_factor >= 1.0, length_factor self.length_factor = length_factor self.extend_pe(torch.tensor(0.0).expand(max_len)) @@ -1555,7 +1555,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero # bias because the small numerical roundoff tends to have a non-random # sign. This module is intended to prevent that. Use a very small - # probability; that should be suffixient to fix the problem. + # probability; that should be sufficient to fix the problem. 
self.balance_keys = Balancer( key_head_dim * num_heads, channel_dim=-1, @@ -1571,7 +1571,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 ) - # the following are for diagnosics only, see --print-diagnostics option + # the following are for diagnostics only, see --print-diagnostics option self.copy_pos_query = Identity() self.copy_query = Identity() @@ -1609,7 +1609,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): k = x[..., query_dim : 2 * query_dim] # p is the position-encoding query p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim + assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim) q = self.copy_query(q) # for diagnostics only, does nothing. k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. From 130a18cc10e614a88e3afa992add2ded109aec8d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 6 Jun 2024 22:27:29 +0800 Subject: [PATCH 169/216] support torch 2.3.1 in docker (#1646) --- .../scripts/docker/generate_build_matrix.py | 9 ++- .github/workflows/build-docker-image.yml | 2 +- .github/workflows/run-docker-image.yml | 2 +- docker/torch2.3.1-cuda11.8.dockerfile | 73 +++++++++++++++++++ docker/torch2.3.1-cuda12.1.dockerfile | 73 +++++++++++++++++++ 5 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 docker/torch2.3.1-cuda11.8.dockerfile create mode 100644 docker/torch2.3.1-cuda12.1.dockerfile diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 77dccb93e..7f13c59bd 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,13 +45,14 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20240223" kaldifeat_version = "1.25.4.dev20240223" - version = "20240401" + version = "20240606" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] torch_version += ["1.13.0", "1.13.1"] torch_version += ["2.0.0", "2.0.1"] torch_version += ["2.1.0", "2.1.1", "2.1.2"] torch_version += ["2.2.0", "2.2.1", "2.2.2"] + torch_version += ["2.3.0", "2.3.1"] matrix = [] for p in python_version: @@ -71,6 +72,12 @@ def get_matrix(): if t == "2.2.2": k2_version_2 = "1.24.4.dev20240328" kaldifeat_version_2 = "1.25.4.dev20240329" + elif t == "2.3.0": + k2_version_2 = "1.24.4.dev20240425" + kaldifeat_version_2 = "1.25.4.dev20240425" + elif t == "2.3.1": + k2_version_2 = "1.24.4.dev20240606" + kaldifeat_version_2 = "1.25.4.dev20240606" matrix.append( { diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index 9198cdb7f..23dcb519f 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", 
"torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index a26e704c5..336d930ca 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v2 diff --git a/docker/torch2.3.1-cuda11.8.dockerfile b/docker/torch2.3.1-cuda11.8.dockerfile new file mode 100644 index 000000000..545b42e9f --- /dev/null +++ b/docker/torch2.3.1-cuda11.8.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240606+cuda11.8.torch2.3.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240606+cuda11.8.torch2.3.1" +ARG TORCHAUDIO_VERSION="2.3.1+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.3.1-cuda12.1.dockerfile b/docker/torch2.3.1-cuda12.1.dockerfile new file mode 100644 index 000000000..ca13752e4 --- /dev/null +++ b/docker/torch2.3.1-cuda12.1.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240606+cuda12.1.torch2.3.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240606+cuda12.1.torch2.3.1" +ARG TORCHAUDIO_VERSION="2.3.1+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL 
kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall From 4d5c1f2e60317bc27a791982eb7e1fcb3603385d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Jun 2024 22:41:54 +0800 Subject: [PATCH 170/216] Remove inf from stored stats (#1647) --- icefall/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/icefall/utils.py b/icefall/utils.py index 0c509fc17..1dbb954de 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1159,7 +1159,8 @@ class MetricsTracker(collections.defaultdict): for k, v in self.items(): ans[k] = v for k, v in other.items(): - ans[k] = ans[k] + v + if v - v == 0: + ans[k] = ans[k] + v return ans def __mul__(self, alpha: float) -> "MetricsTracker": From ec0389a3c1c0cbe93fd04185a20507e6447ba13c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 12 Jun 2024 17:36:57 +0800 Subject: [PATCH 171/216] Add doc about FST-based CTC forced alignment. 
(#1482) --- ...-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav | Bin 0 -> 108844 bytes docs/source/_static/kaldi-align/at.wav | Bin 0 -> 2620 bytes docs/source/_static/kaldi-align/beside.wav | Bin 0 -> 14206 bytes docs/source/_static/kaldi-align/curiosity.wav | Bin 0 -> 22576 bytes docs/source/_static/kaldi-align/had.wav | Bin 0 -> 4550 bytes docs/source/_static/kaldi-align/i.wav | Bin 0 -> 688 bytes docs/source/_static/kaldi-align/me.wav | Bin 0 -> 2620 bytes docs/source/_static/kaldi-align/moment.wav | Bin 0 -> 9702 bytes docs/source/_static/kaldi-align/that.wav | Bin 0 -> 4550 bytes docs/source/_static/kaldi-align/this.wav | Bin 0 -> 5194 bytes docs/source/conf.py | 2 + docs/source/docker/intro.rst | 2 + .../fst-based-forced-alignment/diff.rst | 41 + .../fst-based-forced-alignment/index.rst | 18 + .../fst-based-forced-alignment/k2-based.rst | 4 + .../kaldi-based.rst | 712 ++++++++++++++++++ docs/source/index.rst | 4 +- .../export-ncnn-conv-emformer.rst | 4 +- docs/source/model-export/export-ncnn-lstm.rst | 4 +- .../model-export/export-ncnn-zipformer.rst | 4 +- 20 files changed, 787 insertions(+), 8 deletions(-) create mode 100644 docs/source/_static/kaldi-align/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav create mode 100644 docs/source/_static/kaldi-align/at.wav create mode 100644 docs/source/_static/kaldi-align/beside.wav create mode 100644 docs/source/_static/kaldi-align/curiosity.wav create mode 100644 docs/source/_static/kaldi-align/had.wav create mode 100644 docs/source/_static/kaldi-align/i.wav create mode 100644 docs/source/_static/kaldi-align/me.wav create mode 100644 docs/source/_static/kaldi-align/moment.wav create mode 100644 docs/source/_static/kaldi-align/that.wav create mode 100644 docs/source/_static/kaldi-align/this.wav create mode 100644 docs/source/fst-based-forced-alignment/diff.rst create mode 100644 docs/source/fst-based-forced-alignment/index.rst create mode 100644 docs/source/fst-based-forced-alignment/k2-based.rst create mode 100644 docs/source/fst-based-forced-alignment/kaldi-based.rst diff --git a/docs/source/_static/kaldi-align/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav b/docs/source/_static/kaldi-align/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav new file mode 100644 index 0000000000000000000000000000000000000000..004a33532ea2547c10c0074b967733ba91edd9f8 GIT binary patch literal 108844 zcmXV&1y~zP*T;7QAqj-wu5F?2a_jEye(UaX>+bIE?(XgeHR>(YfkJV2$nN)>yx%j= zZg$7#%$alk=ge%nv~JNN=%g(5Y}}*y@NrWD9VJPU2@4wVFjtZ^lA|L@P6Ey=gLwN3` zX-?>>>7c2isj2D2Q;?=WcGQ&7IB5(TdyQGnmh~DJO`aSnhs&YzCHaheN4_i{miNf} zyVx+gyLn%c1A$^ffOBNEAV`c!?SURBSi*VOy!QFW8LSG}R$QtzlQ)!*d&t$rb2hFZuyN{v^oq&W4Dny)%Z zh3YHvsH#cIQd3n~(n^2TNHtG&lY%50$w?|JRhELKl2UQ0hSWr=BQ>Q*U8L4hEADNi z?ow}QhO|_gCXFR5lvYWrr7f2GG-(d+>!fX@kCG;GpH12}M!THc^Q1kDc!xAg+8~{h zb}{bF(j{oIR+<7e4ik4yI&TT5$hDCcrbxqSYc$kZAuVId7f@z~G?V)}YMVpJIno$v z9wZHsI?~5FQZ1>0RG+IR&-IAuA$4LL?u@5~R7%o8(O}6zvX%^z3(pRcGh_56-kN7$ zC|*qRWBeIv8ua+Y^$yxzP!Dszp!{7R*H&Ar&DDBpS+$tjK<%k^QU6s(lB)@Mr%}&TLKn3iH4Rs1t7B<>nmU~}*3jcC z>UH&=`hpSsf`{G_BGd=)Oq80X=Bk-+LM$AR$^56PM!3ygvgIl-Rgp?dWu%I5U?5a$ zB6Wi!%S#QVy3i(2DhFL#NR6c0@IqUtT^nk*C#Dme(Zj+S9f+w;?xvJ>fu<#(s4wM; zNv@I?EvRsy0s5!YgKzXEA5Q(EekR1h=ik&s;-iT9K#!l&;~VNJXnq?y@8kZI+1RY^ zS5GhtThwLh0;s;2@vem8gQ4X(sIZijb?P?4MkqgDodI<_sr}WV?qRuSSUG}`I$ob;e@Hg_u}_Io+k6$QyoUB$?9CfN_7!+j)nfCn1`m!R3ql77PHcx zr)s>{RNYh`)n6^HdZ`|2MP{=YPZiWwY8BO0bySV2s-!7)svog>HC53O>%`@umQZcf zze$R}LyClpD%TS)eRZrYaMa{z@;Uw=!Dkp>$NbDMOS& zy!BFME1Q)q$|hwa`DQBfl$FX>WrZ@I5@VGy7 zr;JqkD1DVK)IXd$JM-I5nW8Kt|7K+mwXIk7D?637w6cO8EugJ+$_8bJvYYtLw6}r0 zYxUzmDR?~ 
zxGGL{>|fe!EvlM5Bfp;VQ?uU~WVC3-%_y`Xt@!ek^5^!&1%S?0qj^gr%2In`I_vA9UPO zEvr?NcKw^RXq-|jr)*Ex=)9-jn>>hkx3|`e$y{+W6F} zaIDV1jZ^EzRea=s($&hr$##tOd1ZcHa88r_F6O6Jadv%NTKOF;S*85@^2f@A1>N%R z=(^d^N?DaRH)mY-ysY?a&-{X-WO?wb&#fpxtTW+ZNIbf@~!Pi$IqeJh%$a>-XZpn`U@pY0OCDzFH zwmIVYs-kO??d>154y=8rbg1`M(;n--%J`z^MZuaGHUo_(jpKE$`mrt_0$P@@QKeO- z%jK^JS$kAAyfR1R`xTrhYN;8e)mSZ;5Awf1zDm2I4!QAh*WOLORCko#=Hu38Z~V(w?%B-+`uZiRU*{q#qQhRN_==c;$Gs?q}nlW|E|fH zq(qpCf+Ff)YGvDUQloLzCkOsx-^Z%4^;7LT-Au>v-WB|3It|ent=gFKeJ+=pP%)?C z-;%kWt&Fp*$`sbjNy&|oTH9LLm9ieIBVbJczP`s(&8r!@8KR=Vf<-wb{n@W!!=?F#!Eo|jA3sItCx z**d!_jtaHBg}pOI)>f6KP$BBw}vtHky#NhKN&x^_-u;Q zWL43+dNr(Yz5a{F>uSUXFY;;Z`=r#|s?}pqf@JEHG)w|WYTHn83wK7NC_na%a`mm%r|*7a!U_fM6eJ`NvMuk7ln>(Z(gF0~1;?_j#2Fsga7u4?jF@A^8S}iAH2S{|2;e|Ixo=q zP*A(l-%8XeUcF3gh3eG~HgIn~z3JlGhH}pS101sCO!GQ*jK-+;&cBkAmL8S-GqHAB zouUl=AJ>(B={_5r8(ELcPLA#PIreePb)TzKu7q7W^8EPsH%W7nH~mRUn_;-@aj5j2 zY7UK-HQiaeedUh6ldbN?zxX*ZHX)%(eD+`K=$_v;y!j_&#?z6{++IY6Z%m${*y$@d zj`y+-EM1~U=|?p}noex_u6a!3gEhC59~Kbfp6l{|9GwMp6v-Aqd)j8E$2BfQ@!+<& zEG)9PEbQX$?(XjHvbe+I?iv!3K-@j!-a7Li|KxCp&U9B-Rad`y_ddU3yreIA8Q+W; znQV7U=PPH0FGBQ~TBqJ7uSG_R8n_MR?=$Xv%y@VG9rt$0`xANH%!STEldJR}ODRtf zOZA81J~i6Zq<>PwMtpp><^^9vcSh}4DzaPT_TO-4LjizYsYBj9d-m^}f;2XNZ^_)M zrmpLB10jkys2`NjyK#2A#yw(t26rxRaj{Xix-X)~1g%sLQ9L8p2!D8|+Mby1TONCU zN178e$<~rflDGImZn5P<(aOy1@29^Jp*p3z zZ@d^bHY8M|CiU1+aTr7cTrx^^UVaRJ$$fBDSGOt7$-iDSvEpFWz{Z7sfz)%8zg#(SRHODvbnP0XR@52g{8INLd!!REH^w7s;|vKE;pSHCP@U6x#W zz0_TPpyF$l-Binb(Dc2=WghLYc>DR!!E^dEDo2_wt$?WRdSVN<0Qcdy@M(BoOpcz$ z`jW$BQ`P_KRv0E3UB+F;d-_t{Yn@wv(zqn(g7LV%P!ptHsM@JIt8AngD;psFPTs`S z$RfTCTZbjMhOFRQ^@PLoz7x7JnQM`qn%p}qMd?$SunU=y|qH?^K1e4k1YZNkw zga)azG{oJZ z2$$1!-QC%J#@*jt?0V!9xkq|}eM{)~Og(Nm-wa%Nbs)n;jyyz@iN7R$q|KxgrC+2E zsAO5DqPF@E?J3s=>YZ`}>Nv0ky~jJvgGIUM%`=M;y{##m&g)T;fJvnuLWF0KA# z#_YWuhn)@FRqo{;-rdlX@2TzW=37gLGZl0ceSnT+yYUzNzah=h7PykE7E_XI;yL1% zVwrTkbSl+EHedEdHd@XrvQ&pOBeV~7W8lfrq5q^QR_u{!sYGdGacwe|$io{E(@A!DPr$EI{7F zp3VZ;UD*J=(+NB5iIC;J0-cTR#goY6;?0tkl5>&^(nGQViZUgkexe?yxuE``Tq8G7 zXT@5wCGi$V@kZEckp^-5FAD^}gWb)HU^L7y`hf3|ceW?bHORTo(bDmoqmyH!W0Z4- zv)D1$ame1${vGZ=g5$Ysw>#0(!1LO@(A~wo*-d%6d)NB}-$&nR?=H_<_Zat7cY9B& zcLe>FxyrWW5PlaQ>>rL?Mb{BZ$!_X4wCjC^Qc*6SC0{A`%kn7|bw#>Es-$*Nx2bAM zBD)}KCO;^fE32VKQ{Sakk|I(M*rly;oJF?5Aw5M3j3JafS?-SDmU#jUpF5xmzknWvuaAJuhpcR5-m3^ ze_Og(R$7eK8`k>v2_TpczN2F>5k3=NN<1g)OIlIa zWKZN@<#*&M^6iQ($_=XS>c{Gv>S5}$s`aX3RZlgmexRABnX7)Od?Ghdj5JK@m!wOQ zBrftJ5sSA+&mwDq>jH0Ct~YC-(EyV)4$VpUk_hh z?>P5j*C~iHRXEZdg^vCX)N#^&!al?P*q-aiciwZ=b2o9v!feyhJ=-1PIpB%+N_tGi1-Xe> zg`=1r4Irxn2mN*Z3xq-ZTXqdIfIjRKd;j*#bhU8~v8P#wS!SDg({a;bQ({J?z7 z>@y#=6j_Ja70!3A@t&XFD7q6ff<4E%`91z>1a5PB54A(kxHkssc>1MEQczfddUt`M=7T?jnYu{sP&YWxV}CGj z>7zcIcd2)wXN~*2%j-&UQSR;Tj-DvbL-zvrQMcOD#FOdy&l~V*neOarG=rDTZs53xktQCwTBfv3ra(zVhM z>L1vPK27+sEyM@n0+E6jW8Xw+kk4KtB=aZuqihfQs;`?b;2GsEb4~Fs^5nZh9iMD^ zSbId<*4y_wP0n)XXLkmSgPoppz7TGKKS3nL_mQ8)LnTSll~iY$O1@XITXjj@ST#=J zl;u*Lq!E%U;^yKtq>X4y$Z->pLpXW??gHl^ciaY<#j_xr>k@JZ3&CFEy@`3GNt`cP zCv73SAs?q0seG@Ts(PvJr>0duG<)><8jo6~u90(859vT?k0#QU;s;~}Iw-Ii@~!Fw z@`R_tTK*t2*Q;^;=}=f+R*$YqDK!<2DqLK+plEVwu<4Qcjb&m*RmBeTd*^)LRbdZi zmmHJ(m9JFW)nAn+&A<>ca%#l6uqDAA4M)^#VUJ({_5paJ?V=Tamfq?*ZvV@^!2X~8 zzT>WEt*;Yvhpz3zeZkBj;XK+$(nPULIZvfjw^MIW6)2XgC+QP|wgoQ_E;T0USEz5R z^s4!)D%Dj@L)}qzzHF{!AvqfIS{{>ABp)P#IENfatV7p9CZJlR#7gi{_-%Zv=(C^W zm(m{3ad&-prqx!Nn13MM{k6l_?%(4w{>a;5KJQrLTu^zaLTbCgyCtiPOJfElhQve2cYOO^Hed6v$Owg~OqC~mu7CM=6^RNBq1P8Z)zU5x&%i#*KGKE?9 zEwpaLiLlziql_(ryup`3zlKc>{}Ga>>!}zeR*B|;XL(29he#1P%DwhI_D=K$xw4#d zJS}Ovcdhq;Q&BA~wq?#rdy#DZdOeNKJzErPel}=*XP->5;Me 
z1o=gB4AK_;LmI4}Eop(6eICoZ($#r`3!juVF>iHD^SC@a?d!}RZR2SiT_K+nL_{~M z^&z}ZP#g6P?O$ZZ|rZTr5eaa+I}=YonbBv+B-EsEBkL4`|L(P0;I*myPk=eKPCuqOCu#e7p3e z?Z+w4?>%-re(`+q*SM@VW%JNlluWIl#wfDGE;i7#INyF!YoXa6wHJhRp_Z_?<$ift zWkvOdf(99VK23Z3^xgO5Sy=-rTlnI|3f-6xUnmxaN3}>inKZd^L8JS1LP8J8PH|Ig z`Q=NB_Z8JDoRi-zbAD=O^61n9nUhMt*jKlvs(qEM zO{eX3nB~Y@u|oaX@Ly1Z(Wd>TsZ#IObu%Pum#Ipq#iR%;0M5CQC<1FNb;#x_M#_0A zOny)CQkkymt-h>@)7(_ZLIO|z{wJF(gwBkX5Bzs6Mmu)JZPf`uhM4Ig|emvV{$*{udkS9_1Zd`>zfYQ*7!P$2FN=ahD7Ao zY8AUBYF_Y5)n#cj$Wuy0S$@CAVX9YBmbbCMT=HA>TKgex9H~%tP`}VE4LTaWBKlYy zBo5R`j-$h#YB6~}aR{=PHv;F-L@4Hta$0T@H{Y*A$4OSpmnaidW%4J~F=?3i5}AU> z;1AG-fm;5LTqC-jFV%V8o@!lMbEV=yNlVbIt<11|&U=MF#nZ2nvX~}OMdCjG15!fO zQPU@4M&k6kogyw7y6N0(RQB*s7I{a0iVL6Gj&rr}KKeuPz3QBa z9Ic%p&BNzlBPAx~M44S^SktBG-<*eO1>fvxIR(#a{%{yPv)lt*3tZD^1D>uq91FXd^{`3|Pr*({ZNHZ4ul*MEci5?rRr(&f#E|DP zzt!%MV662aj4&k0HW7cI2SkgI?vUNoS=3DkX9_$Ad;%AZ;CKf98G8&_!-~LDt_|Ch zo8w22}q;0j;ajpB^_$v+w-FZuF0x3~O3 z?QgNCiQ{gDuSa#%jEqQ&)5Xr!)`Dk@H{PMS+dl40=~9|yer-;#dQ~zar@@a0g;~xa z=p8IXQeSH~_S0OG{E4`QMaWI*Jl!Kr6!y`wu^1_;TNGV-vSOdv=NRlg#=K{$eZjsR zY;UrKQ4~`Ww;ep*qW3l58q2nxNf)yryOSCKL3PrNgJ1f48;hjb3q5ghCxP8!hQ17uf}1&U*` zp;A6cgoiOugrk!yUnE6asc zTFpQ4wlpcK>$^4zx6rMk7fch&kK)h8^Q;?%Cgh*um-stmQsBJkI9iF6i{#QEZNH!| z`lqtte#SH2cRAoCr%KZ$5Aiag7qgKq=RRWchmK-N15)GjEKpZv>cj7ssv)nV!JJ&vUXRn;O#V>(sjmrWeQ5L!w|6NQfE4AD7 z3EFllmTCsiD$AvJBnjki*kq)G{~Nc138JU_UeOe{nEz8~!sEg~zB(`tYeFs~((s;G zHRQZ!q0i8fqF)A}w7`7-T24>T@UHUZahd)SA&K4T+3RfN{KqvPq8STazq@C-Cp#Zo zwp4K?&azZKpexf4;Llqar;QK7k1J<{9o26mA7RA+bDZKSJ{X?s4=+4gBp zJF}U)hQ5KU^M1lbZVMJkMxhh&jnY$+-O?4H2VEmM7zjn%i91N1$sg**s_Nq^*f z!b@%#KPhk=X%grueB*ksO<6Oq4NUT{5CU2p@}2!FLi@h)aZ>Op(CWQHPorS5NM0+3lu{2E+xbivZ0E`NdH)<=M{$BMjdEhB+Z`@ zSO~|8!OCSdT2FKui2=&i7dS2i0V(k@5bBrn%{Z1-bG`YE{yRXLDn%T!0o=!CWVmQN z_64uN+u*&iW~cz1Bp-N+j0Pw2b5S(RWWRi=4S@=3g(hM;ycfO>BQOpvK|4c6=3=5P zF#=zTMPc=^Bus}{(OKBv@Qm68Z-TW&*Fl_S7o6b$VjbQRZ$~sC+mWNl2=VV?FS&`l zKprIXiI&6!q6hH??~l*HC*!T~=a>dlU^!R<-WF?$ZU8685z#92GS(gcj_pA|p*?XP zKY~BSPGP69q4)%%0l9#{_8MI&>IiK0GN6&(1!y46_XY+a!I1Z}3;7oq zw{+xJ)pQ=n)S~}eacT$F0;x5^-`ig$ycZIoj#`@kVc;|35Zy$RIR@=xg8rHURMZ3@sMJu`y$`k-k0V_8 zd^{U_fn{N_cq2Ry`+{-Ub?g9SD`x>m$X6S_2Q^p1vQ9 z7&bVrHRufRf2;;Z$peH>3|!Ara2S0AN`Dd1X?9>|eg}tknJ5-*hYm;2pe8g5y9qXs z$)ZWfAs`*E1(LuKGYpIp1HcFI7HHCu;NUrhECG%v0(>A3|8Lpe1+|n5ZJa~Lfxi(s9wJcm0%HA7J3p@! zoWTUBj5SqM4cu`kP)aS((Wn);^rMgsuoT*OF!Wj+Sf6BnSHDZxD^&8yu!9f^wP$7s z^@S!vUm;i6?e73A$S7dLJ|lxgDL{{{fd8MO{}AbTKJqshf}DXB&>J$q9`g?Q8_4wY z&`%x!FE{Y>Z8>_rfmVbt`~6LdSKdK0`?g6owGuQvhAEmgoV=LXgTt(1UN2WCQSdk4fdCLtApWhCzIz2QYr+$TCq77-M>(J)l0>Y3S)~ z(0S+|XajT~%m-C4KMcZD*cWKO)!1&VIm{Amv76Xc7?+9I2lNPvxEiX<+Oe z$7W+L7-L!JbCgD#!uO5X8T21?3@~ip;MEFI2KB&*-v}{f8feEw@Gje7PU#CynUA9G z=tFoPDg0GWtO3kF-LWZHM=Sx0#=@~8bOBUY`v{}(H<1Ufg9#ixcY)a71|#e}@U7J_ zhU!CW&Ig__1(@qyz_LCEI?^GV`3t3=}F*JybEMvYnTa#gR9^f(2ZY#_ge{+ z;#{!L--BMS5|u)$)B$So8TeyfK|789zbU@9=ne4j1&BhF0fvE&@IIFj8_da3B0pG? zegjg^38tAn;PhDmwFSz+M#BMLp8+feao46mK*TqpyJb`$cnJ0h+Nd z*aXhQD>nkWOedHL+CpUPm&JA@^sHaD|G&WYwi?c+1da+r^T5ZGAyT0)zz?wi*!?3w zYo3I&@dH=+80@Df;Hb|+FFpdt&VsSQ0~}&ApboR)4!8`)M-lL&g-{QzG0fwMFz$v1a$^5 z=ES3pru zhu;0GZp%>U=fCX3g961+bL=iK+&$rpUch&8m{pF0g#`n4ek*Xv7r^jw7ykV+ROUf_ zl=?6tt?+p?jIm$V$bmo{zJQT)6o~9!rnTFFp8xM92_h5lr4p#*+6VaAGH^wy!S8e& z+B**j=Yw!m2v{6!(EHMXg-2k8lL%z#FIM~t5ct2CXFcrBEQQxy3G=xm@B@4(_uz~? 
z{+2+7ZUFa)9GK%yz=?i?V_Xd%8~@LOPlNdy9>o12;AD9UbS4GdPb}0b>*((emAK-e zQrSLW;Ew>Soa4uV<}ZNR?FaZ^p2JryVTG{(s&Q=s3jPHU{!zfh{;KAg2zG!5fnW7Y z3xVPl*at3JUk#Ma0f#T4DbtVb8+1=2_Pr&uF3ckA$)=9&GY~2mK^DnzF_J0>r6tvtN z=(*M4EIAB4s25PVabWrT35@RwFttR2%b*QVyjnQ_snBMt;WgUAdtQNdRv<25%m=~Q zxZ$kKKx&@?>i-(l8~Xz)lZ61Wy$;$T6g-`Y;PH{c{k%7{@F}?ZzbvlzVO-S$`}R+0 zsVexrUtXGYaOV8Fe|r2fu>8z|-mw)%TLO@#2B7<=15ph5BESmAK*cR(APzo$@x^km z1$x0Ka0_Vm0$9n#fETXv|5luAuq1>+8+V4eJ0{?Sw)|D;?hI7JI|=@R{9mgQXxklN zL)Zs@UkI(80uJzqKr{I5C^+J9IJ*PTzFWX0`pXOS5q_EvznKDlhVF2MLV%+W1?OB> zFh?wdW7`M)Y7^YYTSII0hH-ElTJbix2L{5|Qn;Q?!Cr70erp(;lT zck(7Uv-w~`Y7g!FAs`3G<`0+)ewlQ4!WcLJ*Ek$}CB4BwmjEpv2{X=Rm<7IoM>HUg#zoZ0V7po zIHD}@5qyWfRRDeG9<;+TxFcMIe>Y&&^vfVc$L5<}6SXdc|@ zj*5QyBKrY3eO(BLig=3x^Wj|_LO$MG2lEQTYX7k-LfLl(#- zE?z7C4r`O)NTz5q_81?CX|YnMA^Hz?26}v)h(o`Ck*6{GN)+#3A83!ALYF}gm;|%x z5|IR(Bh&~HXj{bUWo#*=SvLpm>6*(mfz?YN{uVdIUr(ql+J9xJkT4^b?-cZsL{5f$AFn49s{L|c3&`Gy!QS`CbNBXT+Ha2Q2zvF?G3>>^|-cYw1L?hUmi= zaI3K6qIkB_FT*j>X|@P$EZWE45&j}hh+m58`zfi0Y6vT|A><{zHClwO!rT~!S%mli zjnw1!`%BojqF4ASpHKJ@=)!45z1Y2siH=1J@P6DPq=Tr8wF%uM$B9|&O6G`IA^DxJ z#TOz`k|ShuVAR3?EcwDN<(uLW{22BxSe44zqChk32-}IB3v^a(<{ke~x=VcB|3vhP zJb^=PV`P%(fq$@QkvPV0p}+C(@m%y3@_{p=ThQG?3#>P{Sy(5Uh_>NRiRN>=14=BK zwhA-R3idN_|J(T!0hzy(h@cnH%|uNJ9$38pI0QUN$1ywpoA@NYLwJiHA^#@U;!VgY z{!5@+9EAP|49Dkj7yX<0-!Z8A%WbFaeoU0Z&t$Hk`>;g+GsY&`f%Zfy+1sLlew1wo z)IR0^D2zux3wPPyAR=>$%NNp7To{cUML+ug;`?Bi`49APKS_1KQs_bcLh&Z1g1d&b zL$c9vh>3fJ++`p5Lr{azQ+Vp9+1~_~d4jYi?zu*~M#>W8PrRqS`-nZ_)yO{2-~PK~ zF+NoI$sXfbQ7X2~<72zxs{^my5uA$@B&EU-yPGS+%(#kQ$Gj8zk=2--mwG}2l(eU) zvA2#N?uUGyXb--M+sm#&0Nm#GaBumSf|zcMWRXpfX08X`4HA|t<7#|Px~=FDa-E3s zeX{QoJrG|JI+{n?>-h`BoA4Dt-^NP_!)7Rz)5DWPsYB-&BT39P4DHqz$)t~ZU0qp!rzoT>@Ntc_MmnIohciJQtk=*kEiF-OQ`RK8mYE-Pu)K z9lRNlgcSLTd_(+Wi41-!6XlyEvY|WK^RA)ZR_Fxum6xE83JU|*xS8A(PE7an)5r(x z6I<8!6XwFU!W+>pOv2ygLq%zUQ|J!<0~g9a!#852s2w|-nFs!s^T%rFR zPje#v4?Y_F0aI9%|C{X~@&}^a4sU4SFQgfNl_?Ksv9Ap0?#a(a{^Df*V&*1yh=wy&jpKjQ57vS_!Mkb@(>s2C%e!0uTM2>0vA- z+9#}~pL*^5c!;kIV@`;+@Q7!mcRDjDFc~)md=8oO0Amhp77jAqJma}{>_{)?t( zqih1(#nZ}{%We{?{SZ-fR52mw7c^aHW1nx|kFx4%L>*VL?IE8DYuhsD(<_;9zRWvV zh=X{o)vI-P^R5yQyn#^azUq6!#&B=EcbU%oP(IshVU9C1nBB~9;S9rZ^LbP_Q8n(@TK~_Y!O|??5A_- z-M(4OPx=GhnMw7?Jr}r9foOh#;~&cic9zJ;lP;OFO&|=t&a`wExzy}+EC(OK#=3j^ ze)7sd0=tL3z&{k~unT=>m|gxx!UJXweTn(WJmiAd8@}H>?}ajC0-x{cNN4*xcr}ch zdl{I+wDgVg&125{{$c-NQkg-pCtcxJb2e`s$hp$7L-_Whbe8lwy~o)Wfm_HBp~#!) z%Vz!y+z<-b0FIecaQG-3qIf=}a24oK0gp z@Jas8%q?eKPo{81v{snGEb*-IPGQF|7x>G8gQEWIdhbnF8U2aPVKcZqZXjC&XTx}# zF_*lJy_K|$b@`8TbG#16d-pyzNVqJ_VOx6py9aw)_{Q-neGl@Hfe8O7hzs39 z1_T0Nc%3F}^dEzDsV$HUehj()2BHi2`FebC;2QP`xx$Tx9k4uR7NW1dqFPFch@3!gkzh1v@Oz_ zV2gEyd7pdYT}jq$rstMxj&<&dz7XFI_d!Rxt<3qu^9t+-(OrnM z#A7QlH$20op~G>xv^O|1PKjIM3Z7*39qHlp&WJ5)O=O7aSS zh7$M#a;T&jDCXlN1qG=Z`|qJ1d=1_P8;)?oHn@Am`djgfgy|4ns_pN|-C=ytrmI;a zpDB=ni$3UX@b#^D4X4yskV~8eDp! 
zxKTmdyk7Y&i$52=0H&3wJW}zl=AgB$!|HC%ze45sQ1X-5rg)-}8HnJTpz4r+!sbNu z4;O{yhSmwQhnxwS81go_ouP*&SBWY&E92D#%7?NOWDTw+5=l4YPLz@#aT;}qe#2?v z2)UKCkypho$xAAOnjq0h=1|?`bLF(`3>7M!AU;9N#G=7VcpmuG{%nALLx+3f?TbvU z&4;YBZLiJ!YdTfUt0*b2U;cM_?TYg9<>fO=D~cKw+l#sv-zz>=`nGI9#coSI?_?;k zyv-VK|Bbfba#>qtnIcBLjcgY8W#0DZp(Wt-%euJ-``-*F>tEuOvXPwXHTkGrZJMB`s zaCdL|8+V1j#?OH4jZKh$$~Y$3*V|uPMV3JoS;ek`Lxr!aG50V|aHm#(F0Cp5=1dYc zOEjv3nuUsT@cQJVw9*~?S47va)yA$yi*BCghkTQ)ND-p$t1gzk7b~!5fg0|+?@#Yn z8sT%GRkwIsx^=FdF0;h>vf-k(!$n7ve8fDRwnMM!zJ!Zm>BDb zu8DjQc_8|0)SQTJK{|DBDn!~@c3I(-2irPvJRoI}!qMqsG`R ztOETPYRV^Ik*F!~l4Ix{o)(TM%gdVRDnqF)zhm~GAMV0)*4bPXea$?uv~S66(>wP| z^0=XA%#5ga`p=R_#1i%O@WOaLp^%&KBfUcp>DZ21Jv1gQt_# zYl<*eTN}GFd~MjJ%zB^SJ=8asE5SNa87jMOe$d#kff0jib%-4h|0=FtTuA)m_~6=2 zYxD74YrPF~8XxImwbM0SG*i@nDI3T_B#XuWi6=-NlNSj!*;yPfeJ1$`dqL6CXVNgp zb9BoWE9Wcj$qq|DOD9mR#VTyMU&V#_UO6Kzd6lz@b$LU6W~OfZzPB*d*-xb4lI%65 z&WgI`aqLgor|6jtBO9dD@@O}xM;d`sOIVd?O$d)qsP%7fXI&dbq-;Dsk{{~6ZJ*(i z(1VzG_b{s+atf^F^J*lnSKI(xA~uP)lB1|&YDv)Pu$qXDwf4mxO1N12LEWcy*VKL+ z7ZF()Tw^#I1bUBPlYXUoJoOl30_*%!L?@hQuMrm#Z^;qTgR&vA{o?te=3F0NsIL#*nt8=`7giwSu}_3rx?jFT zwN`aYaas0JHd4_|T}wC5U^5Od?$Q0L3;|QibsWJ`kf%a8JK7uL>S*s`Ibdp7Wh?nN z?^ee16#C2c&n>@9{!#^*X`S-aRY~sC+!@%X91BLX+k{fS)rd#?6Rb;xT3#vC**u9*jwJvVt0*YzXg2gX1otsoBB`r zi)Q{{Xc6QHeiBwM@^sV{5KzCU)4x?tYj|U9l6BGe6w7Q9X{)N z>ofDh8h6Fg;+DCuGCHMHd~Non*=O&kurGP(cMHc^6Pcrdhp3)dCO#m+D5**pyfbP* zoIB=QbnnQ6k$WTBhobuX@{hz2|01u(-r8!nI_#V+!^~G+E*)1qy7<2mt~{y6Z~N0T zn%&{=hOLx*Q}oo9=<;=h?myiI{VBs4R~96gd(xz+kZdN! zWVy1NiW}-U6}lmGbMR=rQWYa!tv=Gw$#~OvNxxlvL)Jy? z!<*o}@!5DEd>hQ}P=^zuu8Wa)bOv})s`2^c2gz(%Kc!iHPdimNNk?njXemt{b&@Jc zIbVKD+Dd!}zYV*b69ZNulpo8i^|W`KvwSg$Oj~L^RUIlulr}E>o)eY%BVGPIEOlS% z&mWBoL(DUM^O1YxKU53I{tc67D6VR61zid25&ktSGW2GMA>?w9+91{>%P!#-uCeQP zOHR!w(>QSEG%>Y+_N%CnRJN(Q3H2%RaoR7M zeURn7Sfy0#rH)CCi{;`tvW?*#xAnA9E9mZWtXZK)j2iGYuKvsm9;BQly9qOTD{UT+&RvBklDhgA|~=r z#YgQvW6$6%!G@sk`eWM58j*IPcC_}mx?0glR#$pRe2UB=G6?GbyTPkq_n0GSF+q)! zHIx_1dn>9HCzSQn+qF;h7X3%<7iB(`EDj~x5ts1kcnBT~r0a^nIKK=memD7Z>`C7x zcQ+uj=G%MN=h%8#57i7RpIkH~cU-0+Bjo#-lvk;TvRagXa_&nSCRdcJgW z`LOEVHjP)t=OQPG7t~huLBsrzf5SF}SB7;9{Sx%XfEqd)dK)5iQL0&#N}Prd!uFwu zP#hbAHOBU%^P$J)66Yj`srj;G*#Y?{MTBarx}i2tcU0d?hp9iyP{|cyHQY7FLaNwf z+z9rQN+dY&6f!?Lv86tjyNT9dRCq&(f-59$s>QY#2aJ+tu3YTh8J2TF4x_VLB^3uT- zAF7+1jMXp8h8OM0PstB0uBi-iY-YQnKggld6xkfrJpHraOJVgR7DX%#-yeE7*c((o zxKps*P**)vvJM%^r7}ag3;q`9JA5ROj$IL51^25>q$S!)&QWgp3l*k2W!Mqq4-y59 z)EB7cP<_co_&%bAWD*rCt4rmGN8`hhdti`h#WZ1hGamXAJ&kTjGrlBPTWxc`vgoU? 
zmzwg!a{FeD`*A1zMOwei8zti$a|5yBrP6>nA6o&te)WkmB^J6U=0JQ^twUiReJAx~ zIW4}4i35ISt-I3ty{4wRooSkRbMJ#uc!ZP*O@(9rlgMcFCb~vs^~du6 z(FeVQ-J@Jr9FJ^$%rRBN%Kj=U$<56?nx2(9F72=En-y<7lMpG{gUA=f`48|rk=|0Z z;cOHdUoUQ1WYeINn%W8pbw#`cXONpbLPxt@_Q|$Qmgm);($D$ooT8j>g>2vT86`oI;2$A{VNnr>!>dCgjC-}Q z>Xxcn$}h5El2rT$_%*-zTL;vzTTm=?;7&8E*vrBM^c?v@vPhCb4ksdr*W@c{m3*X1 zr(Ul5sTd?bEITVZF3Y1{NH0oW0$Xv1Tuao&wgq_B=dE;AIfmNLTSHBkDq0p_%ANUh zcG`~b3x2dK$f#-WX-jK7E9@U^J6tpuA-x*3CnhGoFzSIZUG-U6pj{9=CM+oo4}Ps# zBHfErGjeZJm&LZvL|3k=SW@}5YD!hFikPxGrB}<+s*-K@eFgq0*fp3L!-(6`)|!{W z7b0dxjfrd>ekSB)P>N9zv@mF%@x1Poda81ne5O=S{Ed)8Dnz%WupjXZ45>%?DAr29 zU}FNA*cov=hWTQA5?o(-Bl};1@ahKg0!#Xw75A*U}!+Y zHt{^Lue79_Ce%EwY*7-Kdpcu9>OZM}{#aS$HV^R3bDpjlUOuk+pt}=xTr({ENK|ZC zjLxfQq&gcE7c(VsbKHaQT(G1}mkh_tMI_Q+*h#N&wzCYb2~_DT*OiVgnpM=j>_l~p zb+9u^(yCZ@^5)H?-SAYwdu3Swi z#UvrdXQ3Yg3H&wslJ|l4Z(7c+5@z}53kSIX`-ty=T*miHzENV?Woi#KSoWLZrfRF^ zie{<0OzBiyQe2T=r%s4*T!pZ_mX-Q8x+<-*8cW%rqSV~)ng3)Yr3q<(5HuJU{m@e^NlsL zOW68crEm+1z~z%lNrq$#Sg&|_Kq*)MsqO{q{8sYTGMi)@*&W-9tn)A8o#W4r$`)lhe}2vDS-zs`b+I@5Q}(9PU`HCChc}SwB~I)$G6#R73=4l5JGa)o z;Cq@1)k=*`J3sx&n^G*Pw0K#u@LLpZgBz;m|v4t*{^JA$^FvXm0QhQ>1pB^vSZ@|n%y$kn6FoxP<4wsGlE>6k`C3><9hUc#UzAT) z+*Rbs9n?I@WFi9HFFa(Pc}koW_H;`NxWdsz>A4HDmjC?r<3rBp(!Z+umUPWp@guh2 zgjq{(hWmgm5Xar52MYZpOAYSun6PQu`m$uIT*+$}1B2dJ`%H0ztc7^F0CSO9z}S5k zJiA<>_IEX|@{eU3E03A1mV1^=(}tS)HGM5{&K+h9$!<90m^tP2-3sjDGQB697p#P-Wi@3QVMT1GYS8ku;`${A%e&Wf zvgw?WF12f#YlnLdJnwgCohsWdXXUSfR63-+XViqu3fiHK0glrzZzeCI5bCRh zft@0XpT&3Z$M^%BgRb&jrQb6DF#j=Oe14!1?Ap5^gJl|K#%|yn$;Z;GvTKx|jK{Cz z<6wVo4ptpFAav$*z7Fowz$)A}wJ5(;a3=RhZsXka%^?koeL|7II%bb+ zkHubjr_2N6SX^_cBDnZoeo=nkqPAsqtN#W@WT~~xcE<6WC!K4GRZA*mVTxOw7 zHAonhp{v5?h4eKv)GpA})0U{k%B`|F(k|poQLb=7ume4HmP=q3)4wsT`#M z7}$e9jnqctaEweNzlvu`2U3W%giIn=ieF3q6DQyq$VC4h&gbJ?^PP8WQKp>I<@u{~ zW@krb`qC$5k;V5apXVi|4E_2?&Rvt!9qt+B8RG5gY36M&G^Z4X>d=nCg1VMssU|mw z4SgMw6x=GLo$-ObX|hZg6e&xP2PF4~j^eD@!$v^+`r6Fk-f#cA?k9qav^{WspX9McqvmqIe8U-zo9m zczD3e)-X?DcXb?n#P=6Hj@ia!!;XrM?}W6%w&45lxkMGYRI&kz-@TImtC%8RNKKPg zN-s!#WE=FW{}f-wMtRRU81s^<{UzCXY<6namz*?9%sN)8ss79B2{Rr;uMgk`B~AoDcPny4+|rOV|L)mQYZf=`8wh%`o@ zjQkq9Feu(Q!T7K7o&JLQkUT-U9lsgy@@0GrVJA0@Ug@R1&%LueD?JbCj{f0jZR`%@ z;=RSMNXE-Is`51k{d~Pu+e5QSGeO%~TcuvFoFwl--6w~L!gz#P;7jr>am1P;%QEvH zXWjpCAw4uLkW%kk)09!!c=@dA-Nh4g8x%dVe?xXtCzT_#t@W6GwSH#st;i}XJg_CB?hSON1$1lo%(HLWW7l%JA!A^)F(s|Ch_p?R-zrsj4nNGipvKUxpD zXR(6=uh3>hfw;SDmHLa33eSmH6W%uDz2TK^px$DbX}qY9P+z3dh=%BiK&jALxXwrM zPWBx{qT70x`%W>@yfbhc8&AB%6<``_Lv2+4ihgA28yXP9NPT_%0R2cqb3(U!lHfy%Klzz?*X2-I)a8&f0xQ>F>eA8{yhARssgTyPOKV_A&EfgZj z0DrcDD$E~l+*S#O#ENXl3 zd5uK5S20-8K~YDUDze#8SO+(A5C-UvC9?9WLG6K@SRAGKTPx)D#Lf-emh?)8KNhaA=+-J3KQaV?k=hw{gjV$!v>Q|-6+0DEHQS9VgQB!|War58Koj3NOMJzF($VExD+g3} zwAONtb|yOxx?Az@$;;|Wqc?nM?2iOrd`X-;c3G`^;ezp&wzk?MA4+WgA4g{y7uEK^ z?b*Gjn{H4LL{SmDySod!;~ZO$-K{@6j(O}3EKF=f!~#J&hw0w4bDs5oKECjcHRF!8 zzU#iPi;tl1QmLMHt`DwoAIl@iMSP`9qkRlMc1Uy^m$>Qf}NU%vMz)l2IgAonZ3?(*sgO#*b^<=OrMO^=BGBNXRUue zIZYm?ITmy(ykDd?Vs=QEz-)bY-A>hg`F8n6nTaw)I)TQ#}k}bq4j6~x7huJ}NBi~!9l8R;) zvtwAVZ@GJ;BhK>N*tG6M?M#EA_FQFf@uOeaUq*gVXGdG=Lszt|OELvL3SHDVpn12b z0CjKWmB7COtJIN-&iFFxq=J^84PVLpoZk%!BPy~TmqiB&i}bDRsl<&{6EEZu`pkfZ z>Q_=DIhss^lH8}OUluO^J)~q_^&oSmdnR9itRkPtXq8fTLQ|vUR6Trf6 zXX@t!y$?PUdOX|{`Y~Xl;yCdb?jqX9cjxx7ow*zARJJv<9z4&-dt7deYDDkxU3YG< zcns%jJr%J<+}9}|re<||JwH3jJt$t;Q`s)D;jP4tX_akrqpK9_;gO)V<$)duBYYF8 zE`8qj@^KdN*^~F>ho@wn^Ix$ow4~vOqt{NxkSMy^kWnv-z#NL0#aV z<-b1V8_KSO^CT6qP(3tceN4Cbc5&B&9!T0j^Mz9&?HJAV0V~h$zKx#6F3QQcE_$^B zBwZcUA*N|udQ8J;ZPbwvt?DL`haSZe@peR}WRA+B%L_OecsHbZI3GSV>`~~x(9^+v zbSq_bs8_h`<6OUN?`?lwlek5B?y>K)#&a0gDXGUTA@AvsJdB(RL 
zpPhYit)@-Ls$R>w7c@~Nf9l}ut7|?$*V(hkEL7%JBvr1cYE(4)7S_zQ5k;Z*FaZt~UFAp8u&Gxr{-=1l) zHEv``hH{MmX3gr{E?FP5mVPZL3#+%AjMf-$w(uC+O$Opg*kNL~e7jB(>IoUD*v#8) zcEgUU`6aer`QKH4=9)d+WAdFkIrv3bNi>qsEQyG_6Es70Rn}P6i!6d3aZ7~7aAQC& z2$RfHCF$dXS4BLEo)>Kk-=+7;R$yQFg|5fO83xLD)!u;W#ovaL$;;A-(*BYaVw70y zKj=N+Y-gWq4KwU79r9~p-m=_r*+1TuJga}kmil!2yT0tXySXl9V7o~@`ZfEmh_Zwf zSaN%RNy=;Tz0a4(k7vHfO2^o;TrST;s;dZ)$YT?l-$)}{)WuhYI8}?$Kel=K32&D@ zpZ+G|TSU3fh&n=;a4b-ft~x7Mli}!ev=~=O5@ct@BmMulo|ud4SJbpF{qg%@VPvJu zA%*TMdxVUPPKzJjxI?pyWPS8d?Gec({2r9fO<;Dg?L-gJ2gDcY9A$=fn{KUsdEnll zfS{@R&+0r`9I?}H_SV}!SU=ilde*T9D1^k7Yjp$la?K-|AN!A;?kF{Kb!1JSvPMO= ze4?QHul~8G-fenT`c`C~8Qrz>wGQ2rt*IH^yY{YY(g|Br6ZT_e?y@i9Tyt)d>;<3a z74$Y|u{wSLSI-Y6#|1S@_>@xBWO<|b=nlbmRcSC|dh-3%8}e1r`vpIbl>cXVXWPVQ z$+Y@7{lCh$Vj6Hja?pw3q@|VM^-i=KtmCZhjUB6AmW`@9WuDD6BbR9}gx5E`&~#f` zb*qA=|3)2DPbKpFV}1W~w6L{ys<_oy6Zs*{x4_iUYa#yz*9I2q1$7@qU)eB80vU(6 zIk)c%H47x+Uh#I(4>U*OQjO4M0`i#>Tj0}KlB)|!r~kh3E%r;-oa~&$?60r?erC%D zK8(~Q9Yig%n^`-)>p#8o%orEnqWa&0WqAv8|9l}nr+*xld#-mZ={xJu9x{ShBunC4ET#t||Uh`_nDpPxa4&uLrV#zVV41h;M@5 z@dw$SzVEKZjx9F9yur|+wr|}f;{)dgzPGe0XmK2$ytZY_HXBmACu|BSRE$ExyyHyM z>V}x~9-IFZnWRbyQbf*(`4F8Hxg_MPPO9Rh*T^hvKiooS&u$0Z;6Ypj4>IMD9h)lo zBD*AOFTQ|o<6Mq?bw^5}g7n;Z@5g2ReBSkG`wZp#8O}D1bX_lXFsB+j4;W(a*CM4O zzRfhVr1IDB@1;5MpC^B4{^e`w-!>bg1k5iEnkAD*xEk+CEpJ-WNZa65R9bK_mT4UQ z{oFgvo0@l3U$j3D|7l+D6gq0}gq{h*1N&>5DXPU|u>-;@#_hT2JnvMv#yQeVLu!dy zwdsOmC36-Xty~oHI=+8eZimM094UOv4lS7ac~%*Xm0PQhng=pHNslHXtSIJAyfxua zgGrHdgD8y|oM#z`N@yb2g}zO0!X_?`!E=k9s>nZwH8{al-K z@=edDxp#ZsT9et*UB|K?9y#(Rla!LxtIZfaD-#^VM`0=E8d_5+3l>>vv zguM$c2-u?;s|0g0s578xcl6A4&+zPYYRy%KP{4P*?QX&Jf$C)6!wxq3ma@0Sgcb#j z=S37Mn;`4F697|sOMQyB21%2fv_+vAaSa>&8~-_GZrBF>TO}YcN={)%L|xeNRBLas zw+ie}ixCYnvXW)Eo`Q+pXv{v+d+w z%|?tLk=*{Crp$e;68#gIUzt~&tNqwKM^)U*SquA|TK;o`m0jf7leqV}0p zrE|Xz%OT#Ke4Kc#`a;1yrY1A|ug+7tsya;VojvUK@CO~}fFIs%Ri{fv|G8b5_-j)( zoqec~Gy|59rw_H13lb|L|7m7zd!fzymf6iJlP<(`mP;K)g|*;*kNo;4|9tU{^4X@j z{;%qskfgA@-~mAggD;0X(a)FbvEJ-Y=Of^_8f9Enw_s4#pJwg1@l!VV})TjBR z)?ZsMYDz_}Q&4_|yQ67qZO7_>dWCx*)LAxGw;^nI?8-*kCa0T>YCJ!li+U3{QQiQV zK=H;-RW~YT)acFc+!n4WIzV<$U8!y_>j_itb+tYI9Q!^oJM(q($2C`Db?wzN+6iA6&UH=V55!eSAwpF7ZI)$0@}fN;(Zs*&D2dQ#?V| z>(*0_B%`^w!;g^#56g$ueynv?y{>KH3=zMJ`qgY)tIlae+OL-Ln(dETPlVVfmFj-? 
zE=nwGR=%zjt{QG}_;$j*2%Y4tJX0GQ`aUW*?o`b2kk_gTvO6@3dS;2OTT!#iI!1_< zwF&GPb+cjPMmY(#gxRt8gR=!1(4(V*y*r%GY*$S~YZCta_!^%T zo3T8zTV7}PY~|tT6OE20&rY*8PfMr|oQDzY&QL zUsQJst_i&xdL#I9zKt-tXf_lEzr#+_~r|zbFDh(rYkVc~OfE-=m$pF+z%9ZT=VoR|; zu%EJDwH>slJ9gWfSrqjbDmoUxpA~QO-W|zZQrzFN18Ew|1)yVgWn(A%fndKb{j5%N3=jc?{3*1UF&s|>vU(2YS-r57zt*V+- z8p}JEC05i`T{CZ@P`@A`^h72{_&<+mYb2BiM!rBPs0Ur;-{k$znc$GSMtcrXp=>sA zXmQ|~-B4Q6|#o9m%I1f9V-{QyMRTw3Pf%EUO#;a?k zi&7m{+>mXM6-#3!fY?C{Cc8+z^0V?o(k;?R`EU7SFcIoOa`*u9rDT-6N`64rKoUx} zB|f67AS1{(tr8wUZ3&fRgIFgvlZ(k7#6!@OZzq}z8j#CG6X7plQuZF~D=#t==?Gs- zPgh$@OJCanTfX&$v8MWNIbGbLf~$I3om#!HYEs=&Q;7AAt%)nxyV`r!cZ3TUw(vih zZ|n=uyS~R)Flu@t)y8wgMYwi2^PMj}TDFNujo!i^5?{$x;!R>wq9?=9LJ4+)NUhp2_8aJ6e&s=9G_!3?BEwu5DIon)jdT$(HG8%6f|1#wo zcNmT4Jo_HUbf>`;EM7)#C(e^UrJt0O)Wz!mw5xPlO?SmFX%|_R^u8Dpcacn#4wPI6&Yi826v<@i zUT|OQ4uc0Y+>s#inby;b=tMrwGaH-1KziQPrD2} zK_^phyuYcrJ|p#rR`Fd#&7ni0ZpdM5BvFp{5Ql&SlZDuiZ^uG_8T3B5fBr*6;QpCK zlHv)H!SZ!#t$v_x1~?Hosti?GO@jYo8lRyYazT2CMxu>-QjBlrUTX&!C(b z_{84{4!)DWH#jdG2D^lF9LM+MH!#mx4S$ANNo}PTdY(B)ft*pL;+;6u>8(0%_v(M$M0qzC3jE@Ch7)5II>G}Z$?DO&0G@K*o<@+otXZRB0$D)vTs z-?=7HFf)kSLN(@R`ReJ0;v9UHrI#gO#6i^sEwJ% zwSWwK1l61l=Gyy|t_SwRp3l^KV5uk(W`i!(T&_KIQlR*;FpOvgo{QyL!JYj5nZrIO zg!{wU4Bi7+W$B`ASQOxp&c#+^&9N2882>1!Gwc-Y^FQR0fzj3ljl{koG}Ib9j2t6K zsewo$vc-j%5AZs#psBC|9f8H5jlh|0TP()EN>CC1p*k!9-$Ygu({V4fNfJ*SVuH9N z1jKQfdERND<#Ed$Ky`zEBkh?r-r2|$e3b8jYZ#K|&vs||?s`+Xj>2e1D`pmYjf$iE z!ZIv_KJ7CA5-2OkeVN_^{7GM}Sw%hey|(kdah^w>Tij$%dp<{$Pi+AC+89N(;+rkklsj0LCINc4B+I zezvC!jtG#5ppuY_U#8{}@}4QG}g<+}M9@+)$g^JCy1gL;!g z72V;R$O*hD`i$2Dddw@s08?0~w92 z%r5?&+QG3P`~<=Z3PeqVvFG{uf}xFA@a`14Sq3x9n!=99JWH0NT}=&(<)7uHGyu zYU~ubGW;#EH;zTzTj_nGr5PDRdw9NA4_4AX|6LeKJdV1h+FXacsNHVU!MRQ3)%*){P0 z2qShA{|&n2)X>GqSJys4i!^j*fWGP@>ngt#N&vj~HT+LO;@itCBs#ezR6}wpzt8)e zE%O(#TA@F?)_W1V;Un$a#N+5>VL5gbaB@TNRO&J##*ZWOy!Xfzz}PyFT;zfA<4{L5 zlPe~Uvn}}^#4Z*j`rvQa4L$(L5N-(*)T2Eu%v%CBatpk6&|RL1kAV76dTyVz*1AHNC8=>u z6d6#$drZ_FnQC=&|De_OE~3wH9kmObAK931hO^j3csk&sE)vb971Z!}Z?P2OHFXhtf{pQvmL}pAp69p?c@FIp<}-DozmcO{u6I2CU3}NK z4;{eeqN|a?R1b62yz-^@mLbtf}z9U2k`;wlA zj`b~OJ3^_z8`Bj(?MbHlfsXb#x|n|nnusJc1cullya~IQ`elGmOmh!~1tr0_Fzve$|%6y5gD=M}<4 zH%JZAaXvY@lX>WV$tQ{Lo39dE=~L7+>3C`fe-Pdx%=YhQ`r@12qX1dxB$ChVgYWw< zVKtsaelE6`+D}a5uJCP%Ok|;VFSZjt?*9$$APpTMnkpaaOZ0z~a%>0D7(4;ScotqQ z9zq8~JJF%o32q=X2({otc_X?=yoPQgHL-xqh`r_CL%-nl{_*fm_#QFMQwv>_#QD3^ zr}5_gW9S{w{BB65GPnHY*mlt%|3pBS%En^20(KW*J=Jn2QA$t&hRA1j2j_>292>Av zqI~~r`X^rMPovM#gJt7A(~x24ZYo`z=qIQwutv5pdjD6V1G9!YB-i@Xd;xL5tAbDC zQqKV7D0GhQB+7tYh}%senDCkV$xJ|U7zWodorHz(H~&$`6x9Xxnzyg)un;YLg`49W zyqDB#1qYLX@5H`&BgL0xsqQhvee^slCO+}IgY|5jQNBR_Y(T8vq)|o$w(?l{;jE74p_7@(7{NzVK_lT8TYod?vl6(M|8k@1L zd{h4&Y>ht*s>0)8GoOd9fCB*O;0JmMlAs)s!Pfe7$+aZs6wpc7==L*p^Xe)LPJXHLG zF7aQ4v(SZJx1u5F+4U5EhKBJjygk-~!=z&|Bi~M51J%L-{3|?y2qPy74@KQ1El~*0 zfxDu!@me-ku?;#7G9;;ZcgQ9#gYKfG=y-gEC`H_reCwZ0oFwAJ*|65%kPzbrI0n6p z)?ohwyI47}7&ZaBh$G|x>O%OTjbHD2hX~;m3jxX^InYy|2IYCDDO@3;fGs zo`8E{cSXDKh1hs{BUvhI>GlfOWtRY5>lU$!JO-$&{XhzaKzDO5px5$LqO~By`hYdv zapS{aCq7n+O5DMEp>w5Q||eqT6=|3zLx%7u2QSa}KmBpM~&foTa0A&_~J?UHqf z6LU&GN>;!XSiQI$4k0Va1bhITkM)$9LH{QZ9V2OpXCoW%c6f-m0R4M~VjZz0yv~0R>w~RBydWj7B;Fto(34mWVnHI% zx8S?8L6!Jb0ut4ux3LM>|DYB~0xkzAoxhyXr< zmqK6e9KQolqmJ`acn|+rhyr{78BMSYxsJ?BKA2g-Y~tQ_w3%P(2<`qn+ z84%x=;5q0P3?i1`|4NpK3&0#-C%>;~B5kVZr}8NpgZ$_(<#W{zMVMl;GC}!D)+&#c|M+a>m12A#jxj{mHcmha{gL(j^xs!yfj!n9(2WxML?+TrGlW}E4e>56%{ z{fuL+w;4Yb-w9~=kE92r*Q5<3Udea4OF2(-K|4m@1+AL$>qlazA7tir2H2UES>6ky5D+F@Y{%aQA1)|#}U!F5rv`u3;G&R z6Zj=ybU>ScN=>v%D@(&S`Uk-aggx9d-pl3i7rmD}F6S3J1RC7+_7vwO^L#694YI7P 
zxnHX_w5uLg`lYmOnNT|C=es}gj&;VVrb~ul<~Odl&P3}~mqwTl{TBC@kCLnw)5HvH z1+h}R5WgjkRyY+)l*L@TsB%7Y5AiqyF-73tPk>QXRF$aZxg4< z-XsJvwfzvee=1u-*ZI0oGd(>#F?4hOA)Q2B@xJ$psmI>eo~y2TZp`7aCs>CVf-1L_ zFE5Sy{q+0Ry#6J(>q`vX%LbM{tzzo`uIo{E+PH6I)SXBV@XMA9*?ceV0b9@92GoR9SB@nCq}}cH*R7rG-JFOgg&G7p{gJ@8bPD`i zCL*)Q^%9@>wzy2XN>(SkAsa4N=>`Nh3GN-p2RzpWsV^%3lV!{EWF_FLlgsXjV~_#< zR=nI_26#YK{?4L4&`!}<{t|tRen$6U(}g4aIldov#<$Mnb~Lv=uDf6Bs+Lu)Ej^rH z{hj>q>|H|P<%+*6yBAIVonKgid;v-EPm_jLDn zFC`oyQe;EqZg5U|i@Ym)APbiT$ps~%i`Fyh|Ea1~yX2Yjp~|kxEAmN_Y2e8cPWDHK zi3S3ejsl&Blmc?#aepJ|B>V`f6nWr{$T-lDvmgb4s`v^Jom%t5d;@{Y_`2&*lMp13;p8D&R6@Oa&aa0u6G&ejkL>pHbsM@gFsQMsh zC+`&Y72HtD%T`IEBwI<1Y=K;*MgzJAe+ar723d%@nKAoh_v z$oAM^ugWMrSXh+b?z`br{p-=Ms=g=XZ~A%fQ}^6izdGax{*EcunP>ZsGp(UaI9q&1 zWJA^}ri7?MvEctAk+_8D)j?SrSaV6STi#NpRJjx7NH%5|!38c1Ira8|!ZpP-)_cqHg) zs5&e?C_dncE=`xD3X}dJTS>jL|4CYrS$G+k8!I3g$cv1FA3$4?V0<8^hPR120Cttw zpCr8Wr-JT60pN~&pk}!K*zQ{5O(}+bRS9L?{+!H@%$xPC?)U4`+eKr4wfH@;=u+XM z0!_(>+HV%fp6*ezQT%+y$@~Sk0cP(kIi!Cbd?TntpQFps!+4PoQ8| z`T|-a8Vu-?13}Ar4j?i%_OB3n2wefgu9fgfmsmiVtsRV^t{z}$Y zCX*SZX31ZYv0^vT4R49fM;9W4ksE-$ngFyl!Mgpf+0Ipt zEc;X2Y3l%MhUL6D!hFp%(A3yeZ(L^DZdO^^T9#UCZ1e1YJ6<>^xxaWuQ?2M)CK9Z9 z+WBXSt^$gx8gU|nu|VKAcu&kB8;F0BiQ)m`G2%FJCHaM%4txhwh=#;Eyc|>2hS{*~B7!s1Sn_mU&h-m<2k zKh#cfM6p2eR$)~XDXu6iiss5T%4N!QWq|T8MK0)L#LHI#Q`kR%l8#7H#Sl3QpMX}u zU7(%*Is6hpy)pRWss5fWt`tX_?TlruIny-Ml)=*|eqnJzsk zJ0p)#+*b4jrm! zCB6&)LQEodqJN<4fZIHbZ^@2h8qj%EXKx$#TIU*v&aSf!w?fv-7Q~WbPBpK#EU}ha z1MHXWh;4!Ol&dFI;ghli-<3CU!NLmY6*>+dLJpA3kznGxWE;>c?IxwAL*z~5KCoWh zF0LSt;8EBntPRE^d5~6gN9f2kXO7Tam||u#eb@*2HdD8}FFhB$BYoqU)9iDuiLlsz z3tYutkQew9u|%dfO2k{Yc#>?Hk}Iyr!O~ zUZGdg*9&m-?x!>Gg#_O$c$GM)` z%$5PxJ66(aF|DlsTAN+-r7qO?(vTl?GW_C@Xw)CFb}KN@NZy6?@fIHFjR zBljya)LLyPEvG)G>aO~tnyR{>{G^a8PDwui3Qz^461jMR9nXC6f$!JV(!Sh^TK~2@ zw(PXNx6Zb9xBjs1us?Ff0fUi-adTqPXyhl6Dci03t%=tU2s#uR6A>6SFlI~~+&~)F zB6fMq(CDyZQCE7ZthOesenVYMT@Uk3 zTbWH}7r7U(=Y#|NCebD0yL7#H2YFk56s(DU0$M|&zO{b1`l@1sDogXXvYTWhI62A$ z{U9ILmha0o^%-2>9K8L9)nJ-o9Ap}1a?~%be{7s=`kyJ={L0o75OPrOYkDK>B2j4* z>2`UtCN@YP_AXM0S{QvadU#CJ*blMCVz@|ERHvx(;roKffYN}ofla}BY@hO+yp^=6 zxFyLDN+Jy13vYqfKt8~yo++gI13)f+CFtkae6u{G9Q&;X;;oWNi6%oC16y^%)bb=fQVPU$OIwDw0(v)~~?T+rE| zi2+estLCQOs)yBq@*m)YTZ29nB?^Un0K3oA39MW>%M44AX^N@PIMA@J_GT?=XlyJt z{W0IMmfI_wue~gH2-*HUd?v3MCp*;7GtCwPzENVA{5LV{Bgo>{%^j9%jJT&FU&~L$2kjj3dh~G4#8St zjxY|YYhII7xuEKM^^%$>g1C@)CH>n&I8V zn*lT6v;Jm4l^zb50s0 zc+2U9EOp20C6=4+S?qIvG{1=bD0s;q>WIJ_0p~!zu|7~7x-~)+ zof5q#Vn|46aBOf5>l+joT9nHda!q@Vs6nF=s~`Zo@A%Z z;dM@P#XGy$=9(wfH?P|Y7!PpmkIEvz?bufxRvS@uv8+jXT5UG)iMF#WHTSnNjv8+O zIO%x^2;P^WYY0S+kbhClSNBqpssdH6`n0Y<&+12O@2YB*(dt2}s|p>+NX`NV(vpekCAx}TbDaO!!c3z2+n~a{$snq`S#_s;cxlUuZk0iG&#OR6;Los%{&iEq zXG{kt=T+9m&H>&LbReKhwi8_wBB5i%Sb46}B)=;AE{{^KQzP1oAnk8cT~~Zni~vpI zY;aZ|C8;OdfQbT$-bWC4xBn4X*W93w(wBXO-Wqo?uz~e-h#i~38T19Zn3?1CxH?iN z{QU@#cr$KA-{Yigh&mrUhokgb?HYBFdb@r~usp=5zoDL?c%|5{ey#~nB}mTzw`O~= z16qdk7MVaNyQ?SEvkdrBW8FFSv1XZZmazi(&fglF*BPohRXnYjT;06(o#AzLbj7HW zmX*&9?`=hvp~km$PSbPyG0z{WoPEc&W$WoEelL1kLdv$0H*u7_BQ+^jt1hduluLn0 zX_?fih*ULJo|C*GEOO=vYq5t zOhCJXj&lUyz^-NHQx;EOZ!+L|l{@BIi-9-MY`$apX*yls+R(4&eoa@y5yKXvqAsas zQDw&(#&Fj9#g=RCX}oA6?1Md9sp0f}<{0x4^rG9s)%ahMn?wX2ft$tE@*&FUip6rJ z><4fpc9ks#=Cpp2b+`xn53G7-<2*VZe&(MnGz5+Ezx?Tfgj-Cv2J7nCET|8$Qclbr z=2!BA{cWKEXbNt@cMvFeqi#^N0DD^CKmy#B-MX{-RRQjRL&jt?qf%g2ciLd5p&L$>qR}?TrV8?w$tE~eK8gphtcI!%96BFP6IuBg!0TK_U!oWJ-guh3+|CoO zaJS7h)V1Gb@W`mPbSb@s9WC5|E+eNwKEADNi=s&NU0WOA3@Q()2fOwa&2kH;;DZUiqK-Xg9FHrJCN_HoA8KwIPZyPw2wX1EI<=WoUJ9RZvzSr5~%iukNTkCwn7i zaRqow{`Pw~hS(6T(=8R2!L~P!QJ#TJg1=Jq8=j21 
z$b*VS`W?aZ!{CUokr~ndi*4GlDB+*P;KVNpyW?-fuZ-6;oEQs7>mt*`e~07+2J6PD z_QAPh&izK#SaG-EB@ZHLRanXRZEK2~}Ju`=_);siC+*@ulL-k~fv> zsvWie)NHJi)HkyxI5&FVFr=^%`h+aNadDRHlX8GYtSi$m3u+sBCG!e+qJ&mbbzDNe=5aEEx4KKxDQYx7#Gs_#OUa9Sx?b>PDOwC4hf$ELwwfc&tP*bB` zshY3UDn($Y_CR@9S*XAj_hk>I7D*=wDoGGGCnpkLaR|7JMx)~pBRm!Kt29u6sFT01 zu#bPiwdP9L`)mt#B=eZ=Oj~_dePO;k)I7jpSO_?b|53ZCgVY78fX<- zYrg;a_WRcQcKIIqvVALP8{LlK>0;VR&u7}P*Vt&TJNJ}R^KM`mI_KX3yb@+;0^A3Y zp?$HT_$cCU5(16q9Enf*OSV+5QVdhXDsz>xIPtz^oXYs`6O5A&JH0<4KfOg(J?1lnUDDI}sl z`cC+I_}cn{d<}hms+#&uS*UnlQ(tf2W?z;soZd?JX1cL)9L;?PGYO+#3?&2GM;baD z>x@f?JH%u1lz6VBK++w!cpJ)&$ueYlvQx5tAo-Oo?JGSaafv&LH*9h}L8uuy-$% zjFN7aK9N>QJIHp+p2<$i2FN~21EtF)c_4}B0qLyn*d4%fRzUUslR_$g6HslFXo>H$ zcfY5rJHRD4ID3+PnC%YOy)L$3mb>OJ=1|KXOOZuk{nt9xw$pynan zcw#sXNOMRvbPzbj#t2H@4W?J)nD6u!x+9%T4+13PlXL|0mAS`0<_dWnkX(;I7&uF6 zh#w~^$hIIyc3t*aeqFIm*%|B^!__m@ufbZXNS&^Jtg;An|C-O4aMKZEFQe3Gs}D9FGxjzGm`lw+EfU*u zyT{SaH3Q7va=b~ty)@2ti4S_$%4If4tqnW4}80Ryv66`HD6Dvct5BpFS^M&Q>7Qmm1fBv+(|WHaQw!3<`Z^1ITltWw@l{tx(fL%}5GynKu7k<=|w zNmSy?z;*c!O-E4pvZ$^9J}>3QFn4|3DAHT#KJV({blYoe!M1(YzE+v_mF1S@q2-&U zlhtiK4NUh>?2{ZR&S+N@pc7kwqxvKD4)91U%nNoBFcuBv^ZCxeK5|@mB~$_%#|dG( zFdIxa32^!l1#*>>IE1^vhJ$HWJoAB0r*XQ(_sX}=w-S8Lf52VRoL&StA4$wECX&ry z#{yEoa=sbxhJ5lL6-|Q%!z_Fq8G*te4blUDi#H(75fXAdd4_yXa%8A@u=sEBzv5rw zGI5Fcj(9#GZWfW-$ToodJC$Gobq>Q90NeBgV6JKl64T|t>hc8m9G3at0$D1w5RH#q001NQ8Xpj*Ej z@YFkir&w#?OmG7FnjQgm17tDsA7Vr>@H@_+HE45eG4=uzuoQe0Fq7TDf8#LGiWmxJ zQX7f?5!;9bL@MFNzv371aduwONU`ipvl9k8BPa1`^EKH(cnHS>1% zjB#&pWjZy^8IH&H2(TVYwSBg3wf3>bS;bb^stG0$!zsYaG!i5;w!xQyjjI!~3VDJ+Xm?=Ol7lp_9a!Rrayz6Ski^WOe59D0ZM)V`D;)(bb zjKD^qYmm3VM${HsBzgd>HNAy3`~}X*HfLurTj`6SsqOJL@ox5fa)-OOxq@Baop+sk zojsijrv!XA4V-b#Nbqx-bCz?z^QSY#wbk_;tm)>vtKDNgWu6uwYyQ$3POYOdDI*mO zuKTNCPTvvi4O=l!m?-u-JCHMR8+cM!22Q3A`kMhy&>*M)S^+me9w6O8nqd*fVw3Ra zcvJ8>9}#(kkC2hg$rj`kkcXc~g6AsPlN?9>LyjdIlNgytTqe$d9|wu?L|Z~b=m;(N z(TZ43tR@yIq7*&FR|$$tk!+SmOO4sd*E3Y$(}0wQWEUIHlbt)|_mjz0cju^m6`2~ptkZ(pymP(A`BjP^dq2dg2 z8*vT!o0tMtwQZpa;ctF57YLq@ceppKh%Ezp_*p)+?+3Nn_kodtN$_J)5}+inL<_L- zL}SS;`2*Eq%~4&UzIouQprs+FL({_Ugk2497m*WjD)(EFEXqbt_= z+qTA1P(PvWbR}G-_#-Mf`1APp13!BgpDB;3Z2TwiXTXn%zsFV>>h4%Qj%?3o_fqg& zI7z4Z|6mAMb)QiE(hkt&Yb|Pza*?8m>Ljpb-qST!i)BrSgTV1l3T=6g-@rZe`CYAS zTJs;{2-8H}2S)kicL*V4c2N6RcjK{GZ~7 zyaz}#OpzGL{=_#d7hQq80!iIF5CO>B^Z3u~Nv0>5C+AaIYK(V^r^e~9xy_b(mEm1Y zLDl^7c_oL6lm%fw)OjPmuglk#d@I>d{HZ`zpeTA!HotbKWxF%oyNUWtp&&hH_UGbl z07BL%OFWMNs@>0g(wam^GEXK%x;>Y z9e~TQo^`Xe?0mK%uy73LZ2U5zv;P6#g_m(jj0mjw6y75)spF=NHrap|I=-T^*iv-t zmp|`g?z`_le$z#Fi>CZ+kQedIoqwieU2U3aiPdOt;rQS@;~GYN;h*9PMWuS3ri*;N z^qABjUjZ0Nc>(Wr(dxgXM~E%RCpZi4fOLh|_`5QzTm$TLEkiATSx4J)*fG;c(`$C(vrmYiLa96$x+Gj#!njUYxqyh z>hN{J2LeR8gX)2b2GSkG3%F7+0PoE;CW1LYJAE?hhwJsq8iwu`3Db)rf~Sk%^?{g@7+zPgvuBhY%;Ep@Wy6S(@;OJ4)0<}}Ge zkm`s~!P4Fs2MjrZp#u96#15{hI zv-Fd}e!XGnqtJU1KcZ9OMmI1vyc}B_`75$8YG%aS(0@aA24@EK3T&)xEgK-7DA^^+ zkZ6gg@J!JOU>jTNzvAC6bYWe5e=rRW=0yCzz<=5W`VVB3)&Tq4So9at8fpxjLjMAy zq6;Lz1G!KJrKkF&)H~2RYvu|E`~ENHC#EL#g@$8>hjlY*cGT>x-C3)tDXr*IvABF> zMSjI!)uRpn)Ei93x@L7#45#aHKn?EX>g}zgrg%#|CRey;jyH_XVKaaStF`bKuwT95 zcL)Wdw@4Iz9_(bUfOX|D;E?%$6rELg6j{4Q%dOozE|5Tg1b25Gd~kQS!3MX%-F^6R zcPGf;ZVALYuH9+Z(zDN#+g!9%SM9yO_gyQJu#-{DeOW!w5Xz@kG7i}WWwmOE#;4s2 zPMNl8uW9@0UKuwS()9;SY}jM{G-E%b&2ZoN%{WOvNSAE<4bEsjYy9f}DOxN3m3LFW zmOmjL(m&}n=mBB@umaV`HshPH8E`jbFW!{A1ucQT5*v_Q?lBxBFXA-c8&Q)D#QFRa z6i2S|6?_+AtT38;7X0F$8~DaO4CQ$yy1sb+=M7Ykv5t2B;k;PgrD~MTP1po4Lkfv+%)jJKKr2{I zF9CbnmhxV@B&NIUx#mwmh<;Bskabo>s_QE_`D68UZ9`cf;<=(+269{2U}X~>tvUu6 zF>d05;=G(hi=hm4uBH@PjJ}d*%G!&q;8U_w^cVR}%!MwX77(@2USJ)(gqesRgYD=u 
zenX&2xa`~OF~F_3-@)#}LE)g|l#{oOa^!p3xTAu(ffKfC_BoFIm28oQRRzCSK5>;x zCmoHfQ4W)9k>{HKX|Mvgb+-8vys5swgXetf*!}DZ!2=HlHmw&zd0=KJ2D-~Om3k2O z_@P1+B#C|8nc`I043OLY6Fx+L7k_v&kShuXS_{nA4;6{bB{09PfaB0`{5WtS-qfE{ zxPuvdV{H}HOI!hWR(#a<#M~hSBD8zu575EVC+3v=Csu{uRbP?n5bf1F#RS->`d8MU z=8-X65&;lWcvrXwkI~Kb2lzgokJ{UEQ1}n6qdkZ=5EfF&wCIi?=ikFk_%Ylss*{J*|^JgjC!oujoxy!RP2h}?pYr=CHtyQrxti75>&VmM!8w? zm!Ls2P$LWFii2VvsJnYv;yXD){tS}^tjbu;7__CYA=E6ah49@TEgNAxDPz%6p&ztL zTS~-XbM5DlX*F7-BkhAitl>F6%zoH+n4d4ZLLRDG;9gFTMW$4>wQ|@^aN1$_AE~|$ zo#a~i>)W8BMi|cgXLWm~`D+9V#kFN~@w)+xoz0D7T`b49Kwj8txLO5V@MtFFj1Qvp z8h$)oD%7vws8Mp3{pNnhbx^m%#`tGcqR1#&i@-Fl5?3_9#?a9McF7=p$4Ze5T+SdpXMuUb~UlC@9`;QoZ&j&WTLn~}5r zB+&U8fZf8!1=m1_*h7+uy@QoI@xZKDxj58Xk6WtJLFl36>%K zgu$UyLmYbmog!R@4usZLe>NSo-b3#NM&qgFyPV%O{({BMWXxFHQN@;YFS}*V;Xh>G zVm4S4#g}mJiUolp{cg7=bb|SbO!Q{@J;vjqHS$yVZEqAdMy>UYflg{Hpm{OJG=|^h z%GTxahXRzegC4@S7Bhn#)D8)CHzrf+D>mup zLx+{8_yEyNUJL51LD~E2fbLOYC$kf#kvE_p+!Wy-%?4~?newg``mi3ts6tI710$|V zu@FrR0q=N)5_nHOmdQO|Wo>-Noh$uMxVyf;kPeko0u~amFohwuNce;NZ?9PHlJ8*r zj*)?5O46xSR-^8mvM0<*yK(}KX(hqcd(2rI16EJ># zV0iR;scxu|X07HcHvwM9e>9hKGtf4`6k1{G4aZtntEZT@cy#bSY=CSy)~@UgbSvDipCn`RAJ1ZLU1EK5kxnfUjr}0rm4Y+@Nfs1vtp-W|>Lh@j$xKF=1Fx@#^nJI;@?ZXYK{+Qns2JZH>q;KZ?ur_3-=1jAZr9HbYou=X*jl& zZKq7($`FnpBeqbdyO&kGN%`fkw1`n)9B+G}yXo*kh|ya84s!t4Y?gSFs)B10NzzdH zG%P1DMKng0yLt*;>SR?0#h0dRhz-p&y$yvEvz2==k2qadC58(rx)Jb4?z^J5&=K3I zyun65f1-^QSoHy>mC_RWEGV>5@Gxw0by(IH4@MFXf0?^5iL z?qOEA1UU^4S2Y0C^@Z}|p(Bb|DUEwCKZ(`^WPM7~DAK7?|KDhH(|kpNob55sH`mNMs~^hMx+@*fy3O&E{~tWMb)YD?~g5}0X$QA|Ck6J-q^#KX0y z>nZt#DD+0icB0LEy@)V!6JLNLjEU_^J>u7+W1;Ef34gSzpBT&jEn6iSxI=iV&>3`G znxQR1t$`bNh zf4-j-Bisd?rNREr;fbEH@Sy? zM70mj_tjR{6$XR`=#Tg_iN^3B++?Z_yi}Z{nC;b2YTdpH1^Q73Ci2(}%@E%?zgzo- zo9yeZz07xVuGVLUUV0hX9kIW3)Zbdsj#GrfWFPETRrCF`@m8>#Z!1kA{s#9X7yJiw zQ)nR#Bf5vuh`s!A^ptm|P#!>tlfqr%i*%ptznNlWh|pGcK2TS&ReB_J2gZ}mP{7w!H7^vBZxh1;^Hui% zcTd6&ieJGU;S%*G_ydQasa&;oHG3yGR=$$@<$DQ#rW_b1EfDsr()^k5WB8gd7(2x@ z@UGLOh9V>*VF70femFop3XN7YcE{o0NQLLEkfnIwd4WZPoMow)1UCvy#vn)tZGnu^ zL}UZ#!L}k9&sDfNa)N7zS-I`x9d9sUJ_#WhLCrZOZ7+M>N_O^z6&gbR0|V9ifJ>CZnHA{DzzFEI z*cfVr=t83r5i1U5ViQqq;1fR7pTZCvhU^kk5DT2lU4&MG=AQ^$?hxEc+(R7a9UQCZ zFJk;OGB)%W{wZJJ30L@~0bH_rZSZflzoMV#nOmL>60LsXHbEj@9T>v3lSKz+ z!Zo12!Hw8WxVF!a%!A~<%k&l~H@F3QPb;`SFt7o$72-@LUs%MgrD5Q&tCaO(4dQg{ zqu2w+gf79q$P2K|dS252>t!j@Zi6>hlRB0EKj$#xEU3IsS6m5fZJnsD75hxBKqqUxu*7eXca2RxRwh180(VqMQ#RU@d%XJE3h#=)ad9mxSW z!gquQTxaDGylVxd9I8efTcLPWC-D%Z#kvGOfh7DXUny;*2L)zxE49`BWl}%oQSlFm zK}T@|nN4)9>S5SWrf0B(U4iJxi6AL51GtVe#EDRI79!tsT>_{gFK|p)i2oFt3txE? 
zJPw?P-sB!YG`^1CEUYDmim}0K)O>I@|)orh#WQ%i-Eg&E*d5M2~R=a01K%NZG!YemkLqf>6WB@ zU7QQmNB`p+3npngN+3l-IGhK#5?R6lj0f)9EkY`M5wuh@g#=)Qy9m7!E2U>pC9sm5 zm+WFYXpq<%8iWKvrhGjFLtFUSuuD2FREZ0LeR~3ImL7vNdQUJToG8=)*45_1Y>@iZ zNiz{O{I_%rS_JKp-om+%O!@(rNoSGz;Av6~vV=wi+n~;p-%vFf(yXG*q(0-?9~@Rnsl4=UaE#q;XdpX zNIkLWeJ~H23MhW%$W-J!l0h86E}%WJW8j&y1yy5%u}j!Yd0vi%3<<8Jei;3Tjf3pW`;++mcXx2GhsD)6H6i9kRI7;3Na>pVJnnf!zp&F0FHNu}oZt~7zth*-Iy?7OTJUScx9wj}|BTLh zRMgv9DAc9ys(b4$g^7_(qWXrf2wxNNI(%06ePc+ymA(%D8JgkRQnjRPQ7KXZ4zmKw z@j)6@cuc&kj;}7A+^=r;T7%+7M!yX&H%`+&mj6i%N8gBP>@GkH`OWTuno(C|=Vgr* zTNRCDqj9ijVc+{!dx~8qR~z?n|1x$j+a`G3_lIkaHM%$_qwGVs7v{&a@BO~h^vUF3 zyMm`R7ULDu%fx4m&os?Rp!BU&q;8culNiN+1iY)-)%8kFWVQJD_Uj*CyZzjsTUxQ) z8-w1`eU7_VYkcjDT1-N6Y_sqeniS9y8G{}LwqMpc*}AKueQAw?p_ZCO|5mnlZD8l( zpOq&~0Wm4`m-m%*Q0e%5OPS-gNyCUDq@QN%6swqjE8}`MENNBLY(&be+P7ko^iC>5$nvaq&2uWMTUy&!>@QUXS(U+q?MR_Ao^D6C0xrMX?uj&tKoouagqq48Exw5yaQel!UBah;Z=sU93 z%qF5UK7)8kpP<^}SJ1IU9eSPm{_G#mNOy-pKqR z`CD^yvuatcmQVEzfjcPDBmYbCG?F!|Zg{WO&G0PwI>7~)V<*ZIOM8`RE7#kuIu^R; zdX(PzJ}7iV{EeNKr|COIC}OgrVog0&3keq=aIdhos;=oDgkIGyh`v;NWW6`_9JTvJ zf6~B&jqes56POje%np`rW5?;svL~_@ijK-->SSF*<0JEu@T1}Xm}+Yd(-vq?P~%eC zHdhU{Ua&oI?)3cS8|?jM>sS0OGx^JdSG6A!x8FTqS9)Fb+&hi$)YuWnTfYdQ|by!j2UKyaxDxvgW!@F)Qo*Q~zAO#x<@PN6}-1 z3*H;mQ!DCL?6+mOga6M?yxgBBz+KV2lZep4mKG+ygjKy56-o38MS4LOX z^&G}lm}b<+_haA0~jGOng4zw4S5=UgU<-3^#S+YW-D$$1Y zd4_$4J^G)jH8cl(4Qy~Ov5m0(wzqU8dtL_^eg#-*KXW`Tb>JyZQ_LuZ9pNhY`hD%Z)< zC)p>{|M$7Z3;3qv{D=FEOPVOp)L0aoS95*KH$7xs4%FHpFBDpW6W`YM)WUb!>wnkJ zJyE1C%PKM#>8m>nlVq*+DUmG_@Z@_b;kAQdRk8#j&bg`dcwTb0DaTxRwz7f4?M2vO zfX_4(4kr?1g~}?mpiMJ0FmzQufkp3Z+mP}mB~W>WuNg5wuZrtZZ%pGxjdGLv$NphF zz)TXm1z!dpv7{gu20(w2Llu81^MP${yFyS-P!sB5ip$g!v{0JDT@GCJOmu8^obc=g ze5@G0R0tD}ahA{{*PV*%`OmVx{@nZS<>R`ydf(t*v(5zLjfCMfwMkQ2U+SLPF1kh{ z)l_K9?eVR#?kGsf8JFqLwUi`P-l@1;RqTC-*|p1}X2su3I8i$#@oL0JSp~nx;VF*K z6SEd(JMy2E`|Vx5UxO3CivAJ&1U*IDRJ*k0x^R659jfj|v<%g?Z?BwD`9DV-zlVNp z!sAEOJ>RfR!(g3pvAvBwl@)lhm;mx2b;ME9AZQZOmg=dRt6iXduX(9`toyC$t!zah zNHIIh*T((axzI&>=6Tx(Rl;^C29AVM#a=Asr>lDx-pOq7wes!E)TmqQZq0tpRcUo! zYwS-fPB_%$zm8kl{ZaR^rZ>ilVeAW6Udf+XTYk>{J|-h4FQcel#ZmVutk^KC#)3M@ zbw4HjiuaorGT+!K_HiXg^J^B2ESgz%t_pQn+~fTU?uB>+T7$KhW$K!Qhazr=Ei?R7 zX^ESmA@+F{HLIq0S4qd{mAZ9NP{NU-TD~qU?oc~mTCt^y0^L1?i^oauw(Endr{~qDg`n0muF|C-EuPX{?CW+ z)1E6HpS`!}(T(r3#Q(+Du0Olh$eI=P-ZXWkoQ)bL8;3_Cxk6!Ji|uEA+|R$iz5IQ_ zvb;oXT^fv$#YKhcl&18l>#lJ*j8d&b-2O$jJ;hl? zC($D`&y20Y7K9;&k&1lm5})Q>>TK)zn>&i8%I<5;VLc+T=$bLbQMJPF8ao+2fLW8k zECGA6UR+`5KhRMe1Z7j4>Vkf;sjum<{+n_*y$`)9G!1HfG2ZU}i~!-cc^0_8y5D%? zJkicZ*7YU7v$y>;eY8A3_o(sXAzvQ)*T!6Hye5T@9UGUGfYo%GuhCy&w=@ZKzqWI; zs_$h}-xhp-k@a6mEvw}1gNKa9+Q%Ej)^CsyZ)yQdYh$E}P_pZWb(Hll#}DUS>+h1~ z`4Rd4vUJxp;U`f~8E&i&kB&HG!c_6(TM1*sLkGoNqA61cWb_ouE9%plBz-N@Xp>+X zW^NgF+t@(Wl~^P$uz6*W!J2t7GJK`X^(x*~l4cBUy~-(S2x=bRxLgonZf2 zRj-UKxRZOwl3U)&+W?s)lWA7!9%=h%4lBRO?$dLr_S8$JAkM9a97sbzRAcD{ihN~^q8mubdht^fqS&ikCI5pwj?@=Q zyz^}fODgiBb8@o|EFPb4QvgM&lcuz3jv#XbA&dKfg zSMh(Lz21SE)TeJ_P*-pKB-%+9Zh%F9v_|%d9g;98t0;~ zh3A+D7-D60`K48p((k2?xzj6^{^2YMxGwXn@dSB$?G&(%*b;d%etpuj2CbVcOuiI( zf?Vh8QK>6>m%l52X`xj9(K*oP3P8bn!QH?=zE^S4_Rk9CVNS~!CzvFz&U1Tq#L{w z>H?>O9BC_IhJT#xa>4xX*Iv-~3U2nkD?DA3zJqNMz9VL#yfrsdXsY}Xu`|JwT&MBD zWH|CSannN~7 zU%m)`F4Yyja;?~pd?{27d!&3(2lpeCvcrmaRi>H+NzN~_?*srSz+0fjNKfofz{WfZ zx8hg&OPp{el^^~q@a5WD?SIo=kzZ?;m69tXfpbL~?XQUh_0wX)lZG|Os9#y*kvd6? 
zbuKQx_&YVjP(00k&;GG2+R`GUJmXX06xU&Bj%rN!wV2&;RQ$c@wWfb`%?wfDM0BmF z%{mtC?f+zJUa_lumGzy&=y?^0;17vGs2f&FXP9c#ay9za#MhuADIu<>xlB_>^;pqL zae~IsRQ`;=pL0`nBWKs(WvCRtK^>FzQS4F7rhg;<6X!vnk-)05sSU(S3n_0{OlHw(W?Ta7!lFu%!O80x9Yh*hT~w|Ld^b;^l|#nc%7qKmGI zsC?`Asb#%Kb?4XFSZiF(4e>|g0WCYqt{p|3 zSh^ysUa<+{F% zc{BaZmQOu1ld42%zx+9I(X+Y=Og@=y5wnuEHyYEVyl#!ilQIWd2N{F>z>?@HB7t33 zc{V3KEg?NM&sl-H-v<@kw?K^7<{u%xpuZdSH7?cuQVWR>H!o0XnQ_cHS+*ipS%g{mGqKCuZztfHQJsybR5W>&;bPMDPxozk}9vy?T7esi@f4HL1hq=FoaH%5C& z-$H|ZTReBXUjpp_t?{l@Q+mve@F#k1dV2cb01nDW*R;yRxt)Jp{m|oM;9Jeii=|j_ zFRTpRsk&TtHn>9dC@MZ-V{J0=dem{vTq;OSllzo2wNct_j6HO(Vv%LAg)Dh$Tj?I} zY2x|mEVS*mmN^Hp0sN|ZxOqkR2U7#BgV{pCiZJ~{Q&(fErmO4&YT-Kj>-$ZjoT?9~ zLP^LIz+yF#g-VxUNyNEmF?xGcMfhan6>Ue|1mk+s9NlKcD6$6vCKbXXVX5#h9~L_8 zJ>vWuG>;oPCV3|LZg^wecWldwPX50B^~L9^AHOqg`TNUDeDm4Gjxhyw@=tm$D9^7}GyWNV zLB2*cT5%F^eJ@hAmGMSL#Ir~&?1JW<%tP&>KGOfmXKTJ_R>@M(KSA<73#mcY$`#-Vb4v3~S6gdV^^s>W8rfpTOW@8TzEk112!hlRo7tOY;A&?uVZWYY_7+syxi4D+P}56!kFObz{6~ zlfj_Oga-%z_GWq;1Z%M4L!Z5vy;;TY5^qUf#R-Sx9S1mSxsq}h6-=PKgIZ^^-2a2hqNYR|qmzpzTf;BG1+%>M(Bp7B# zc8+}$Js_-(Zj^E{Q=hKGOk|97OL7RdNUFgu3M>sh6AnQ0A(1~69PZ0;cLo%KCxKq< z*x+sVd}~7Ko`P?=m$DyZAIo#+?z4O<(b(!&bjx<6$qJgduLxVm&^mSjp(-IC;BR&BrKHIQtqM0LsLKx%^nc_CUzglA`kVYI|bJ(&ra_x|3P*e ze~RnDriV7MDSV-%Cw|CJ>mQqU8INdcs8jULa3SVxTtwWDb|mV;_i|R4GgsKEt)hag>AE2&Up@wU?er{Rd7#t`ahOp{y|*Nap?%R02ojk1rLUlTn0NDG};Wo z`vJ+{!9Ur*FYrBp`aiqX&P4k~+ji?z>p80u@G>Ub=i92P{x&rO69MFk)itWa> z0s_YW08Q^Yc!G_=gZMLI47rpX4RYe!sb2IKx&e51TV!VdEjgMo(+#P&WIEZGdI0`z zfXsLqXa_im&xAy5AZt=}sGp>n><_vDv%qTc2YL&c0Ve>~IXDsF*Rb)SUxB^;#l9!r z1n&b6>iOlq<@w+~=pN)T0`|P$pAzU2xEYukXa>&t*9G&zb(PND<|cBhxkBJ&?j~4; z3*s#4m6R&2mu#TRIT-MfN22#oAKDGOiZ#ZI@ms_I@+%1eg7J0o5@`g_3Sg-wn}XLE z0u-+Tsu>-kUoi19k!c35!gPwGw$ojh7Yri1%-o|f`U$n1!l-Sa>rj_^Pwk;H$TDIw z;DR*8`(b0zT}Tx?7#;+D0lVfOTsy8BaB@_kPl4P1-@atvA*<^<v559s_A<5_i zGy^+Mw52rQ{(pzMN8P06gQR{R%16xx{ggD2)=#1aP|qkG{WtxJX6Xn<1KMic>8|v4 z`Ul+?u=~~m)ypqx2o*yKq?wvReWiX;J*ipbVWI*bhquD(;@MafhNAiKMo1;)3d4kW zkQe`(OJtt}hX#84ulmA#hrKagt+$uAq4y_93)k{~@W%V*`LzCN{!#wf{!{)wfwsZ! 
zp&9IaE}6f@|0y8iBalk-LnDy>=r?o#X2*8pm3V7{Bi@r;s1x9Q9HyRwK3*c-lx|MX z0}YEgG)f<#W>AAb>)-*Urkl{usR(KZ*@YYqh>J8el-fvrq-3C7R!v2M9@_-E4_!qa zpdzV%NCS9EE+)DWiG-YZhI_CtD1n|vnj<^lpHMx}`|2;g=6iDq?5Ge5dO7og<$zQl z?H}XI^1k)9^u6`<^uO|3{h9u~{+)i-|G&V3V3$xF`;m?3rf?-(D*r|JE`E`Mpf9x& zaUiLn9R?8%$##^I-UeDiYiVFU1J9gtx{}^Y>*!;ETaQvScrOTjAG9F0gWrEwM#mJ> zw}5Z;A$<|FjCk-K2UC^66#t#fA?@IM(NmwuWb!?r?KC8p0kRHFPA8_|tuY3TLpH&6 z-~)i4zd(!vq%)1KMhVz0$J(rfkPdjv*-bSkZ9Rx<_6G(e2oCQp;&AoF@kUPt~=woq1Ewt;yK zI#Y*%8~q`{;xq71fQ6w!laS%?Y$#EBDRkm9*%6_lz$E`K?=a6rm(Ka!9$~*!?Xuad z&Z@mtp{kSC(Y6b=C${kF-u77MZ&#jsBtV>w^yvaOf zKbicNyjor# z8~nlGWNxF-M>+$2ge7DW;8MOIZjd{vWO^JumwrmO1Xs#aK&-DV>n$4yx|FHRcqT;e zr^kR_^IAF$JV(ZW)=>`QV!kt%nN^IOc@Msg8q71Msq7!XW%&kbQy(|qcW6H$R z9%W*grM$9YOx1a7c=b8^A!jRh(DTsuZy-6entjR57ivq>pnqX6(gdrlL)#`s?a_J?`_)8SOHA!GrjSP82E=odA~O7a9n&=k1$sF^}$4S1F5 zviq`Kvf;AbGErug|D$NCyrcxY6ZKI|Kog3nBshz34t^KXNqiv-9M^m7#tDdU5 zt^BQ64X7@$^a}DFUWAr{>HQTxk#z{q?6YJB;@lG#O93r6I>x7aOyOPr-B z_gJnvZ&%*)yyuqR1(S<+mewmzsk~urWgqB_bEkNF`jdhMp{ZO6KVMWq`{5{bAvOyC ziSHz4lZBv5c9k4W{Y@_e4Y5be7Uo~@%q=C#Rfo%8#*WksJ#4wv1nm~SQQ4V>}dRZsH& z5u}0nWC2KDltZb=G;AAQh#QF$#1ybXtx1gs_m@|6L$I2QVYEzV@EfcnOOt&BpLHC# zsyC}XsKd0lZlUhH?t?B{*I)0`FE-3JBp9~qqxAE1>$Pn)oNAHsll(b@(sRkT_%`$@ zv{Bf=rhosKiE+a9a0S8xMckAEi2mJs+1(jQxaw;-MoQ^`J*mrA5((lwb+Oq48H zW(0S=Lox*T^vsG!impmQ`3cZyDQyGY2)(FpVQ6KT3aAk4jjS=&WHEL(RvP9SzUt5D zMr$qV0jed6qq6t3mAs38LZ(a2_-UbY{tli$99^tm$}g8JE1Fht!!pVO2&b0n+_yP@ z=S1Yr&eZ{v=yXedL2_}s(*5PTs-{<8cARzX@SOL(3A_z0=NbyX#k$ZJxI5Yj%f$Kv z-fb&jZO$X7QUhowJ)1cT*3j=5NH$JZQ+^-(0_DnFlOP&;2fRoedS*h%HsYN zno7GM6Y>hh@NB##@s2n}wx=dh|55AdVayd~C%7v5Gas2`Su5FG*)Q2kd8)#rgw*{t zY1)qZ(S}=qeR|8d+ceO8&-}>zulX-CZa!l286%9C;f9XYd{rG&9G1PJmy)lsGjI>F zAJ;E<5g2AJ+W;)3Y)&y#xGA45*in#JfLNB~w#sdpw=8d9UVlp{-&j;#94dWW5o6=* zVXlwvG2SEoL!r(52=O4;sq{c1v59ymqJ)@CDyS&hPRD_^_Zs;G#URCTc~CY-o~HPv ze4;w7R%^rcMTYvO8Rq3-`@tZfP=R{h<3&Prjbu#ZW?$PUYUo{@p zQ@Nk&g;&E$@vl&k*X^ubeWCn!K~>hzUu)8TrS1E*D8rI-B4I zByB(Ft(1oJrK%Ob)KfIa)qd4t%@zHmuzOLBVm?P@N3@JM9MLvX7io)F7+DzEHPR9G z+3;F>0<`vHw7oQMl^M(#;wZWTv_YG3^8yb&^_=UheM`>dEzC;F>XS7jd$i?Y;f{j% zEL+;ptY77s9+PjDZEX2pRj=H}Pz3veBc%GES-l97@vV8kv=XFdijbaA8(}p&Hq?V( z3$MU#A>)KgA&7HHb~FiZ2D-%+1R!H5_o+h4eVU-LIHGe@Ol0?PBz#8Xz?k>3Y)p;l z6;XAgCPqa^o(fMezcKx7dS;xg`=81vw}X4k4{9a;0J3p4gU`InT&C*2<@*cbbC+cI z$X<|Dw_tkFnSucsZPWPNXZAgSLu#*xD1BA^7<4)N2w6fjd=#q((rF*1tw?Ko@zMStRmceQ0X=pINB$UB^0A39ZIvL&tZH15E=VYJ2mC#TRHZHt;4(Zu}J0FoD)v#JWkdC=(! 
zqT{7#E+eqg8*ntOQk6vK8}cUR6lcvV8d&^a?)tP@zb2P8@ke_jD-RS{%g*?_OO1s8 zLOh~q7$G_{3m}ym)SPCEVgK8{;S-cSuG2=f*v{Ze~wIlSlq}{;y7En zE`LdBiVqS}*)E}HK^6Z2*5RWF4Oxq5LMrk9;1r||b{$DYK0{;0yX=jS49w#;OCz|k zfsP>t?4|EYmjHu*3E{$LknQCux_QQt#(ez)gFS3Z^s<8P6FkIbz3k{R&6^=mbc6i4V-@*w^Uc`l3(Iz5DgteRK4qA(_ZRdzukSJXFW-OoGe zV#!L6(tf_km$#xU)mNY275Kx~+}|uT2D*-sSY5ORss=pwTqp%z0n+6H^i({|Z4Kpx z^ujW!oZlXhc@;sKbXoFn**t@^!|q^P2w2ueeMqmZ){<)Edr zOV($1|9R+lR(UN?Z)O^!1QfoG{@jog{zWv#9f%)kLX_aeh!cH@_eS?4 z)1f#(8c_)?kB8xsp)+d@FR&Sm7PwZwbgzFW6?GSHe=jHC>I;)C2hjlV5ZBrXj^J2d zM^~R}q5M>FTEU9^ofXH+!g9u^VOg=ZOdr^TSUAg0>&@UX775Jr{^gq(8V$E2o`Yxh z2Y5E1Db7NAV;GT-42Kp8=Q%U~|Eq@8LWdyj9qpaZs->mu$U%(sosqE8faS(FckDVJ9Gz zj|jLtk?z%Yv~p8%+k!s}c2p%+yvWo2vS;11H4JRBjVRnw@W^&Q^qgxDyyy@42e4Jp zIlvKFjy?yUCXHQ&hhV+Q{@5PK%{S&Y^S8j7Y?d%H@Y&Pcdx-5QJ!4M8M_^U3+eDfc8dLS{^{ow~4HfERinj7Rkc2*m z_d@GQGg!Y*beB8j_74?jikSj!;cn3XeUp3W_dmJk9a!M1HKnku&}SDzbGbLcR)Orm z26jE70hGG`cedTe3edCg0qkEQ4ebYE!WcdWqn9VZD3eqj5d8V3cy4O8Jg)IX}f7B>twodz-j9Jm#668FKcg$rB-=x(b4rD3tKE705X%imQhLi&mZK&J<6 z2XqrTi5|rap)Qk4<#lxz4WkSbw5K&8gT)YOd}m5GoYYln|EGPanXFl_{zo-O{*fL* zhT|R3GN_>-5B>0ta1N-h0PNEzCEF}FiWgLS%fA8R?);*r{^o(vRNS;yl}i%e7IeSQNqTs=^geXY_ukJ7B?v2Ja&C5! zcpn)^y=NxL=F{7#_KFocyWu}WEB$-jPC)kSW@>I;VCrKytlO_G(Uxls>idf0vQf-A z%7UGP+eumc=+GhWGDo6qT4izh@shZFskou7xU4#7ZuYp6lYj!U&RSNosd7fhkBpL{ z!~%Z5xCu3aGlL!UeR4WD&pZxt&F`qa^i@1xI>k>1*2&${DrlQHHI(Lk>?g(Lz|gRs zy~-9zk!TxoDsvQ&Ec(#_#Z3J@@Ljji+x2qOK$FgV)I7zs!GP%>>)PpB>c*=%c@}M; zgVYCH57!h-?9;$N?(G+vSwpivm;D<$5SVEFRr*dpr@JC6mj>H6dW>>)%0!ft2(5O;31x9UWBw~}E6yK|=$2duqI=41}Z$SUp-njcEDbtyYg zbvyJK?Sj?<^u`MiL5x-uDrYHP$Wj#PvSq|5{DCYWcaUk4lY0y3jcXwTTqcYNta0D- zl?vfdA$KE~8)86Wz7)Gi5%Pn~Svp7FUDwzA(=^R6(GYKXZ(bZeI%0QN8&kGnp3z{u zt&h+yRBe!FGT^)!`vM2VDz3;s#eL8|&HAZgaLK29U(SreQrm~p1)0y%(Sl)tZ~PjE zzIcA3oIo<(#jjRAhJdNU^dXl=ea*uMK@&RSRN6SbRCF_T75Vi?VKx%y~vIydX zGu`jp2e>v+0>2J$HUUvboCgmgAIN+RIN6kU(y)g0rbVWIOyTB6=B%*Z5xR&)VXe(? 
zO_V7f>{w#7RjO>oab_b?6a4|*l2Z9;{^O4Qwiw&+${t0Nb5~|hDy??NN`|F<`QEQ+ zyLeZ;V82`PZ-pl`kvK$6Bp-rvk~R1w)mCkD%@bv~x?K5%jKHSK`m5f_Iw1qNbL=QC zP5g?ym!|pa*lV~t@Y|#yn-qu%9RYW(hVTbsscfR`HysZgm@fS>Q>LklImVoCdK+c} z&W=H0nP%2B${4Me>!P(_->BS6bC?mygrcEbew+WJN&b$2H0E}zAEB$ z*e&zsFnRdx@c9vLkOQ8n+okOQ@>Gr0jJz-P5nT`ap>lCKi~2HMa>s?L(&8qTjXCi- z5yewnN6T;gXz~7X_9lKie%=>bez?5S`xWWSsFbVeYhXpPP?@Exse%=Gx`8Gi(@{); zC1#8~4Yvu)f*-+zOJv)E7WFuLS8H4EIQA~E1bhimfZ*w4UkgV`gQB_mkamP_f^N9( zgt0X8Wpv%h17St5buOK{PqISsIt0PQfBda3+i)<0KFDgEKxLF=~F3uWTKVrA;qzcz$8vbW&sa+-e zMy|sn&}iwSa720w3EX1W*UFt$e*jK)Sn2TGG3o1m1b+NkHqCRO6i@H=DK&3SFdW(l z+zcCp#e!LC2k(-t*NoJsnHCxplphMwgUsE+H_LRu**Hs31mk>8|6hU2?#0zFtj(+b zEpZfIw}%H8`4HRT%0=E}{xFZp#fs`$onLn5D=uZ10m9uQMI za)vQO(_2#$+zG1Xm+8;oTyiv~mm2fc++VzzE%z;VX8j*WR{`C`)~&}gajh#*+}-8k z4i|U1Kyi0>cXxMpcXuc)&{Co9Z5q$`dHa7WtA*9HN#@v`vp=(E+20nA$W2Y{mz14& z{15$WNcs#HRq$U@t#6~!_VCy6AEHg9@a6bOYBH%;2N-Ku>V=MsOf(NC^yDAo)u1-U z0&)+&UMzEWa&32gbQG3uEL{SXhC`XEv@LnPJOREQjx+Y1K8wR`ny!uCcQGM|#co^k_)Bj-V1*(Uy z(i)Bb1Rjk#6JgV&OCA&msp3b7RFsNYq^7R=g^t{myc@-Boj_PAJeJxqg(}zu-pxDh zEj{&Q3#tKrT_W{JXSYfrC&Ly8pAFm_oF36GYBQW*i@-&-b?7ZyFVk-G7^^2R!17jo z7)}^ZiPqRIxk$uf2JN5UGY8-wc!r|!lkC|j^ustBebr!AR~ z*C2mMNoV&}-tTH$5STl%#8g(k_(5?se*`_ANkR<-y&{liozpfhWJXwS@VXEwbXLea zxPoKlP^D=x_d}0aBTYTcGp&Ut+GsN*8T;s_;pMrSzD!|*?3L#5$9+v)Q|EFQmE|ZuDGnd)#)+L)c!=$wAXZ$322J)^k zvL7*&Y)HC6mz@aMbSXK5d_WzeaoS6t1}}0Es{Q*^+rjTK5_keps%gwc&@FWVjG>T- zhgYjXEX39-xsp}N5X?eVez*UF_ojCaApA`{k6dZ4YVOzWOxI`EJxBvryn((%?{aU6 zcaQHbcZKf_inFRhbzwD}STJD`@EJ^?Xp9njizJ*CW0VP?;iQyc#S2-C3CaL8A5F%F zf%|+qs6X35-qAi{F%e0gqvGfe^iX;ReFLzGZp=;9GWC7%tX!y0R@c%r(?kOfq}H|5 zj?)yY>#6?&dV3#O+54Gf>MQk}`a(S>D-etD7Wgwfk{E;!$86Xb%nVxLiD)j$#;$@A zG+y})T7UyeJ7p$R=Dvb&?Pc&S9LLr0pND*N%y0Fb^W?hMcoIEc_W<`H_jgZy-x1#i zpT@V+_rM>(M+xELVW`{&qcTO0E~Dx=MrM=8safDaxtMg5tAP*j7nml^0Uh2+dWmO5 z1!5(rIS3+zn2(>wZh?|C3LJHluqL>IUBdd|Pw`#&2)qtKQXi;P>KJv7ilJwLBdrBe zc50|5s}2GO;3{2%K2I-U1~b2@Yh)(TmY51Ykmc~9_!+D^)(!noI-%W=E#3@WLFbVd zbh0Y^Bd~}fi57&K%%j>sc3}YXkPd*kG6uX1dVuezkp-U+)pRC;xy%$WZt(Iu2TIW$ zcqEptJeNmGBSlG=CQRXr{8jv|{d0T_uvtEOKl_IJ61?>wWhUP5^EtHz7A_ zzr0Lof*wP<(pYHEcGL#y2=$kqqMFMtgt=c;{X!kAX`mUb=?3WVAHe93seiD)*>>t$ z>O;WyD97s9rm7a8`y4~}p!?HyI)N^r8iVJQNH(Fg)HA9k&C>U%>C|;97-mB^Kz1v@ z$u^IEO`in+h6(g#dN3VA4FV1G8sZ>60^5bEqc%!9q+6U2Gr;3>2tN#*?pE?Gc#-?c z#qhs*E8LHxz(qb=nFQzIuGl#I8Znu4lC`J>kbP1a*hZ76s?=8Cd3_)f@%`8qpr{8E}CVbVu2O!_DdkoU{O z<@WM)xtDwx*4}KT68!!Hg+MQq70P(fSl>_vC?T*C+?EH*F|r2qRq4`JNSzFobEKhC zBk2?5RPw;HiBx7OA?P7mfK|gk;H`*NA* zEO4t5NSS0Q31;(kST|$I7`W!6#9iVAag|8MJK%4yJJ@dQAT|LrV&_pMM4=Mk|FwW? 
znXHUd-oX{~k_J2v%PY+kyHrO8HNLzUk_N^=+GetBQ`Uk*<|5^{qJp*P4w6td+KKtF zd7x%b20!l%d=j=9z9Wl(?mOPBWX?~;Ia4z>>KFO7vha^7I$K+u}|1__{&q+ zBKSNw-W#{#cFd1uVilpHIRmc&31Thr1Na7fIKCQ3z_+sFZ}B^LRiYCSOL&QHWFc{Y zSV5FPYyJRtTnt)^CvFo)$TuO~?zuc)-Xs5z2PwOh-O42;Uulgx zqyA_yQo|=|0e-UIIhSX&KzGkysFfq%mv;5qn9 z{0ROX?g{mX1_a=F#BgE=p(dUKXVZ><0)5AEm}%?qig;xl!%sm!Kf!Y03z8MVbN3>) z6!r)+;L#PUg;j?3;GyqsAp^{=e~?RwMi$63z6D;+^ObM~D$a0SbzoF)m99uH;A@f8 zQ(`0@_i&lDTHCL~frO6yy=?!&Nd z)b=a?JA`1X_N@~U#yZT zzmY2`#mZdeveHc%sPs|Z!~12x3@(IzO@r&52%qPxydR$B0g+oRAD8FKYve9)^UQ8M%PKGB`?4ZmmmA4>QY!dX%#)f*+a)0X$mgUP@V_#tL`sno-f%g-1nx{Taqw}*YvNIO0`>)-{|4i!2lRM1 zd^wyPUMnPcMey*w542Vv)CF8}LSYR*0>66&Ta0~&jK2N&KCCYu3M*oJ{15gEdy4rH z2??CjkQJqX&sAH5lxwi=il_y?0$%GUN`!IKi@1#+Cz}%lQ5)-`bWn`SDXb$Nf}O`y z$|@yCt|bL2^HDkMEWQD&0xaLwcqF+3auE^qnrYe<06c>&uFxzY=jL1;GC1Dx(R68DJ|{64H>X8bd@1wOkITZfzQ zZ&(~!4tYo4rxuoCNc!m*^0R z_+%lLi}P>xP4O-Co^cntPP+Pm$FkMlr{r~U%c60`)f}6gQ(YEkj)QRx^St$%!G-;| zbX7Tyy@s^D0g%+?Cqq=rVGr?6-2uF>+p1TxUtyoM1e~Lnsd_;AbstdMA^H^kA0#ig zfHZBtY8N zH#{3lH>Gq-7oZyQrBuoPnpar#rl?L?clUho3~LN4couV2^#qX4n#>%nU*F9%$wXKN z7y~pFREwD|bSxD}mjPpRDejhj2%>+HH^qC?7sM~)J8(6*hyD%T3ZBirHbOWsJ}%>- zv|5FM+j&Lx0Ik;a1n^Kggb5iR+#8TEDx^;Ez@S^fi-Tqb?6dYWm1&cK`8`>+6SDNG zQ@P|A!hojApJ7h^C%u%a!(8qU^E*r2Ch7rJcSs+=6QwKLl34?uF4636Vp!hoN?}zzTV#9?l8y%x@Uhg|-OW6E-BQD8wGr$o9doKotOkCg^T|6-a4+5vLmGri}MRsWEG|DOWK}+?2Vij zXM9#j$(Yiwj(lgRe>LiWk5u>59=D|!zEfS8%ErpJTIN$0S5TXv8y2GhbZ)9f#6ie> z8p@FPP-(wdQM}@wS>|^5oZB76&PLw8!aDA-cc*_j|4p1Ld{uJDJlesK>Rq}n`d!9b zCN|)9P7o72zz|T3RW- z6WU1=q~p?k;AZcT<_o8|?Y=?onvV0uQhxi~E}5HCYo<0Svpa7(FXr9KKTujw66h#$ zZ${H7iS4Hev4mKMLH}RWg;;4|-ML_dWLWnbEP7d;LB?PX>=Go??@%7YiSMjm1KNXf zkg+q!z1JJXFM<=|HYrd#C0fOJC4pMcW@;&8viYy&kNKr-zAZDLU(hFDj@@ZS^O0(dx0d@#<)lecxOh|;D4YddkjNk8TXGYDy&3|FfC2sp z_o&i6g@3X~WV+Kfq`b&@;aK2yy2cljlq3|HOLFZ;JT~-_tgqXszilzE;71pU;I6juwWJ{G55tP3)4(2s-417e$M`>XD0u9^vJp}CnZgY}^l=_F5! 
ztAr@Ny+q5;cfm~r}V3gi)Abxd#^D%E>_IUl*&yP|jnwIIu> zJL}?&CSb`PWE*IIsV(d_Ltk@O^J)Ec?Mr3^-dHXWajBnhT&N~xNFm~D-z0Zem*QC9 z+~mT&XZ&k<@U((mN|@3O>rIBU+q7QefPjqP(9nA!Z9_%|PYFH~bTuRuQl4*z`OFDa&!XUGrTn-5b3P1y+8HpQT%>JFqb-2GjukfH8be z8)1kyE(4y_EWKWHKs6l_fZLJlfu}MVPAc2+NocElP%PlPas~cQpt5@}Psalx*-b-_ zWp=5is{3i$=@JYL%mJ34=HKQereTm!6k{Bv-=Vt%NxCnz$?Obf2R)LSM{LBGV$ERB z+YqY@`^WcakvIli(?X=hoZZ`8SRvH$)%9}jnZ8M`6Lz~js-!eOv+#PMrf^FB*@Dud z2H?oOuHa{JJ;*dU?r7v`;Qzrj2WIpLxufzBOT&jzjhRMjqo$m8p|(I`#oQ&eVEJ?1T$LS<6bC=Z<1evw7gAo@9+>1C+t=pj>gK3koM0u@a#Q;#X6 z-!O+&Q`w)Yx{yCVjeQPD6}b#Y>dACqm&HH^$!q!@s)+2;cO_jJA?AZ`^)vn(^!`H8 z;XUae=U>hR`d0YQc=f&+T(R#n-;J;0%jNq2FJaVo+J7E0H9q;HxdYq}Zlz}cH{UTF+Wjc}amOk5C)o|QkE+^NOn1ay_S0W zeh3fcbd)OxdNBxA3GsJ!6UGDNabJdSh$fZ2>TfNiGGo|f;Ht5c+^?=D{E&aKEyTZ6 zH-3irT3yq3N*Ig3$0gCN)PddmNaZHB)c*y`hZA>us4m5ktI;*Qx<7yvwK&mT`U;iX zDM}T26}1YDfi$gXw!GAkFsQT935DWgRBusbu{AeKwUu7XUExOJ2=>W0V0{?{?%8H< znJ|u=Ac2@<)jjl$x6s zLL9dMkCEOAp&CdN7c8b}e318yb{%97jKf+>eM}j24>41Iq%~1`6eW$L|Ej-YFYq#~ zA~ROmM}Olti+e4B-XzEcpr|fNG?ve0qvP~6?-a@lSMsme5Y?k6a+?i{_|5(aDv3>U za`G0YgV0QB!dcl&{Hwn<;i0-KBhgF01GzMlfUMt>IV)#toF0nVs#yU(uR+8@@IN=8 zG-4VSp!C67>qFfV7pW7Zp1wP13L1=g1sh(DKSPAcE-qH=NbG_1R+ zNfE*+@ZSuUF4DK%F(?w7O-J|+Vt?81koqu>xJ_Tj=lKq({a7vO0Ch{P^3NtpXuWBWqQHqo+mPqW7@2hx&xv9VHuOjtgLz!ui=~_yR!{*C>7)aj{C3%si zqVEz`$gjZnpx)9=ES!76eo*R3?RYD$pkmn~)W;mY3`hcsmSlGpIRk4*)pakz%jx@L zKZ|>4rV)#{sRB#c6_>xZZlv3Wf2J=~=6)-ofBOl6mM z5PGkk=U5{|Y6r=w_6$g_y6rcMFW5$&YE&(06d5P!gzNf3L_3V2n=pEc_?m|4%qi?0 zenBzneSS(kLaZz*pOdJg9E2itBU zqqrS!r5>dEApA=_q-xU+%^z13-Ox7D(H{H3v{WeiDgIvWpej``mGe}iT5tE!s0DBM7%Y3_p0Q5%mpIIcfq(`34_YSQge_e z2kinEqt$2?RY30KCW#^)CO;;|3;*C}iErQtdrj5dJ5$%X>=0)*JjIv#X46U1R^<_u z@2)S$Q;UQn&?d?57n*5M36g#1Z2cj{e4(a4dt4X{necvbE+nwb@pRM9)$MmTWUXw9 za#D1oTLz!%jI%DVNPbYKecNyY^;~(!Xr$3dQK@w*r%D(fFi`y9I-)Kj`^(n|uTP>R z^=mKZMV4Jm1%8$isDCE5mq*jDg;dm1H^{$KJ&YW!Y-BciE*e(*%LB)ymNC&ghFPj! zEB6zwQ!f|~s_kyitkzUT7wwNU#$a1vU+?kACEm&Qz07TNNp6p6e9=sh{(4Cy_ukTn zyi&#pNddFvEV0z>vbNVn6kR2{1lIREeKSobnL)n(n#SaN^-J^(TO-Dpcy5J1jQ*wB zC;#HslX2FY-YQ}dEodv1WU5x_rl50Nyw*UTC$jusm^Rb{sU~Rq#xV&{$GSjP7i_u( z!Um{Ge5JH47PDVO@GD^H42HBchF)pectZ^R#|HN|`1X$`^! 
zq=hn0Ke;op06&1WL(8=xSU!J3y&46nSDA1Vg(t&+QjZb~nHF?$!qz@6$EYC~@q8!0lX ztZMFhKvh-hF^bF?+^i|zidg*IGSEKB>-aS&$1m=2l{7n3t_BDUp{{-nZ zW4Q!wJ>2d8`qJfts@~q~EGt}4=J~d2wh4dvXZSECNvuY0LTkZebQAV~yofbscFGe` zB(WO1h^`A;==<^`)oEol87@>dJQeCo_pxiL82JugXqZI4_9baL(I)AN{Df8j#mp>D zX6=SNZdu9mfG^~1={4$$Um(^CgfyI7B#@f^;9<}ToWEuM09g~&RdwN=guph&Nly(q z--zUc?oeuqb(y!bd$L6@T@mNdd#G1R7P(Ow@1G;~vsa_-?y8GyemWJ;C9VDmF&h^q>f%T{rn z;VNoZ6LUTDmAx8<4k_y;glJ-caXSWFalWy$(OBna)n6*tH=o>%Z4}nyz0@7VIPXWQ zn05Ix;qLPXYX&Rz7Pj^IL2Pnu3#msjGu$Q8m8^iib& z|E~b-FnNy1lEy+!IG&s)|4ZOlZE)f$#7}7ogyGzNauVdP?h;3e3#e*puT)(sRF3LW z(N5`?ybRr8B2g{j5k7-%C4ZNbPDa&;$ZSDQ3c;3*_2XbppvD#;&ecSzmmh0E7B$<7}yTYA@8mb5m1dgDZY|R zFj3hkG?HeaeacO-6PAeD$u9wKTC4PYp*cq3$VTH(&=v6HBOP3`e~J?#&sD!^hzrm0tKLay-5REy5LC zL)|2g;UUB={5Cm{yo5)C&LR*TV;SlnrYhZ%lBv_Aog9PjCZec8_##+WNURl_grC4p zqcj-@Ok^GM$ra>@QU|CepOmjioK#z(B@0wm%SoNYC~OGqZ)wQ}*@7#fX4DN0lG~zC z=z$}!W3(y`a3IK2u7En?p7IKlAosC~vH-JSN8}dh79Ne=f&FS1dAJgRO@eCiAaoVI zhLq3lfREdhwpbq2U-zKCN*Bx}7s=lc2=IWrfSt`3 zd6o1N@Frf;D%(&W>5Ft-E{8tLZIwInZfSye8M`HC0y1|H>fL!#qF7$JuDleR%Pa~4 zPS`DQFpidIOB3WP@*HWJv|YZW43XC;hvehRLV2=s6;vKi6&+fq^g(~*CTJ|$0^Eev ziXDOT1N8I?(65lliDqK^Q4t!6qOl#=Wb8LouLJSxfZM%81{mSxFdsSvyYgPxV5sxn z!p16%l<8;;b_}reTi6EdIbcq$u_b^*-i5lv7QkTCSasm&_JMb+k7MWz))rig=VHOw zIj9v6#2x{LwHqqdyU-wLlSDKPwB$W7E7lt82RLCms)4qmB&cV%Ks~WufOJ9K5^ygO z#@-0LH5LZw{ZFh9TB{7fu42_-D!;2k2+8u;xW=p=d$2*D^c4pK#NEFV3BE7|~D0TR%3a8rh!c!S=-IE;oiX#s6C z90g%@QC&0$>idE+OsRxDhS4zrwZRU+Z0rutNXdNI+Y4H6B>J5D&ffoD~ za8Dn=p1eNbO~KdztQmF}$;gEXC>1z9MrfSba) zbPlpJCt(u0gCzr!rNz>*fxx-Ahm`<^aRaLDZ}7F)UrdL$#OL5ycmbeek04R;DxQp= z#U1z@A{$=_P9?#_Xrc?=3sA|u;0!SfcjA$NFZKXMP(MJ)>VXFS0NRRWV{ZUWVjwlt zf<6KU(gvM_{{E~4p}+DEK$0H8yx1crKwaX6{0T7s%Wy`#B%f3mK%dSi!i{7LC(@m9D{00De6c4}hf0+2l;(4G~QpCQp$AsHfCUswu@l z4a7#Rf}FpD6x^MtX;c(~?M(=tGAAIXO$qqG$qA;S69 zoW}pa8{?(D?_6V@fzFMNb!9io(jD!cGn{`Ni=0Ky$F9e&eD^AEs&AjKIiyOffdc?X zZ0sFT5ftN1RA<>xP^w?k&ecyf1RJlI9+@{<##?t;tJ{v-8rq(MpWa;Sc}tL`g*nQ6 z!Sv9mGk!KW^gVT5wKw6h@d*BAvEACym zvG8zFT+yP!%>{i6*uv_C7xQED`NH6mKSd3T`jl36)N+4yJ3YJnEyRNoc(a0E%?v!B z7(#WS%Y*wtM|DT2)*RK2)34Jn)(3)DaVJ9?;{sD{Q@pX2>57@KGLDW^itjjXGg&$51HZ^|moJeC!eyECtS z-hzBz(aW-tt|6{%o+R$R_!9O2bpQ>!i;p8tkin2(d=yj@v8o{OG0A5)tGlYhz@fy$ zt^lut5$YQ1uk0xGF?B`FWbJd^9Q_sjQ$wt2w)uy-v&9T*gSFP@R>mgSP6o6J_-N~7 zJ8O-#zO#HbCm3rOzUYFr9n_Unh}u9BkQx&Pu1J~S1wW#UE_qfsF@IOCCx^-#ogbOs zG&eSR0d!la0?)JspKEX zE*(wb^k{k-T}J1DmVKP6Kt-ybK}~O&dV#u}x*dG(9qKvi)tU>sdxmi1ALB}MtaYPp zwe7F1eZcd82?1Mek8O7YQUmG*G_^Ih4Y5@M9j)6m*vJ^B=@`vy)mZu@SrL@K55(p? z?;GIh=h#s)r*Lfk(cEe||K72bIht6A?8=Py)J6y<>WZT2a zS~>4|w)oBb2mXxM8Q1~60iUf0`219C7uFE(2`P}Z$v)I~>H$3qYKl3~(nHm~)FU;? znwy}-->Z={6|@eWVyI%OVG4r^x!JbHmT6lPa4F!j?E;MUZPsB{vvs`XmYFeEH@!FZ zG|GBX7pa-a9$=bKtq4CVkl%=n_(b1C&lRV&EUS1(!SGybPXFwa*>djee04sS+dJn- zPE793oN+n*a#M3p=RSa++ly#>*RsQ96P=>xo4Xc2LW45ijl zd#LjC9cG9sS0%8wHBYp)bkRdby-Z8YZ7i(yfHli1 zTc5zVDQ78X-ebx(UV*gw&-!%T7j3-eFgSf2U}IH3=~>iCq8KAlWBH6&U%0}Z@zwM6 zbd7QhExlhDl)oz1k$pKUC3|B2$HEPTd|rHB&4P~w%VAyVTiBpTDHu@@Rg_x##qr)b z#C6rPh}$H*70pUbtOb4@FC>;j=I(a-IQzED#-pX6l!_M4a+0r!x2?OUb6(kjlB~irdF`?%{k`$0&2J&$d&0oq zty3ej9~C?-S>SBOUqIK$%8XjW>tb|M*}?Q_vIsoKdiyIlhuVdbz> z@}6v{5T*BAKU%MvMPBj&5a7^QC z;1>O;z&&B-!p29OEY~V}XTV2x75NlbOH&-n^Uvp|=ReP9@|*>|95wx0`GKMjh0{OS zTIxaIbZXMP&@Iy+Gi)@?v&{rwfj2=#f%5`-TW07!Q$OUbey8I}VOCDX>`~e2IsfLT zmk92$;!v`>F57AepH;3>MSo0}ivIGY5$^+i#;uwy2|N;u0qb` z7nRnq!ct*gY3-xZcBe$8tN+qT(a9-!EuG)o9lS^6 zW$Yzg2i+*mX+3A!Wg$)djMpq*17ib|EY|>GX`uhatdgU>eadPS4b3{9qWxn@%}Rfh zwK#vVW1Tpkz71Y@F%iRJuqwXF*_BsUCMrfpI&24kmwbsB$v<;HbLKglIt!freWcu+ z+@VU;xvc+%_6*+=<_jGj+9l|&X^BG<1HT|qlgtDQqATdrc>$;N@M?h=Z(_VIZR4? 
z(zVnJ;EjB}=&EN1+1$7^uyI8HiWh3`uhpf7rfSnFpJTk?T3b6^mTCjBPLlk?eVcuY zy*}?E{x{l0Hevc`l1ycRcS7cbHVZdIJ_#Qlbj!R$KUO`MtS`><>@7Q5vMT>V#>Hgf zw$ivA9eds9BcFJsIMe0fr_*O_*Fq}=y$D}7y!gSaA8 zgR3(a&&X?$)$wnqj5JyaY0 zvfQeusEEJeX(8XNdhKrNC2A>bcFU!6i<=g!3nOypW|pS!O^Z%i5%=NKfsX?cx+fpV zteS80UsNl>f#us&4Xqc`{6veD4Y!pa6ll;+#Tt~Lq>*3bk0(BV`{4PoF5a3gIR5cJ z#tB1k$l$Qy;F+)>0uO5z&<50P9n1?1R?S284NY6^e$`zP!@A11d561a$;rIJTx(wU zoYJh#dClzYe52)?kfJ##fCzsVdAD3fg>Et3D)fyiA2vJiz4?UhD>H+fiv5s}iL`tO z8Hp)$P4#O%Yi$>NFDyBtG`ePm>J>Idt0J5sNQU<`#+^A z#d{0gd4sa=rHQ{&zfbv;^0wo`VU+)H|{ofP&s##2pIw_>A*jkZ*~6+Bs& zL&kI3yzjq<#m$Pplu#>f)6euw*xm6XkcT~GYaDt#L>>GrAlLfESfV~g&!$?_HL0fL z7|=pxRd~uL^_e(5raV6UdzvM+`C8le$vkRkLHf|y|U3)XoA2F<4yGnAEj#aP4 z)Qmn9emKZ$Mw&{hhm@NfjQ@xJ;(mGxI{{qDy6Vmvr<$r7OHBi<(*pdq(bgT-m$n33 zSMxHxPxF?2K)1m|<-7b_|4DCc*O0R8(hDVZ3Onb_Ne@Wjes}#b{Y(9iOFtcp4^M8K z)1`EaxLBOa@dC9>^pMP27rhlvTt5Igl!u|G{ z{_5m$U7jhydM#jHU`&A8!s**+H>gLk>zEt#5vD%70GvB}s>U)?=&Sf$p_li&6Dt!6 zFXk@JU07fR@6;W>G^LKZt3@B|4*fib}Q$Z=zzu8wU0_)KIl0S}Sq~wT1Jm$dLw}~n@BjM#t2VJi&c>2F*KoPHMrC^*tZ1$DVI0?WU1|5N5G-BEVjndmy=!{q6dsJ>yC zXuA?J51b>Ohu;g`82q2@y&*(1in&KD07ObBm6x|GdoY}Or5dK2XNtExw#5cvp}Np` z@X1LCdJ;I$*3&%L@J=^g6RDa<9)T0}5@j+dF_s7?AoC}V9}CQ_3%<6V-eqHomgH5> z@}~^=mGEQfuYrHkQzJ5k!aJUtij#SuKVnM{8&Q6Eg@a*FZKbAZhMwdo_l5jgIqNcs zq)%}_6Y8do%~Iv;C`SGivZ30q6HT`RW(2LXeKvhGR5t9@B2^B>k%s~8-oad?LYPQ6 z!4y#Wm;>!ph6zjDf9yX?gX~{R<`p-vA9PRV@5$?lAoX@*2U}!d=b);=BZC(OzO?K& z_R~MsM5}_SF60=plsJTUA-2%%)Q>c4wVmMjH#lHPz|VjY!N#yHVe>f%sJ7~&k6^GK7^c`)bp})(mtdJYtB4~xRy16}5hbt?6P&6>FYRZS7_QVP4oic8v zyYf<8m!xONBJG}%!*-*ykL!c`u>XQ|0T>^f)r<9tNw!uBSQhvp@Oa>1TSrR^Q>-Br ze1!Tj<7tr^N7bY}G^I}0_A|()p|*^m$Dz9LTj8T4Iz;RV+ZTK(Al7oga6{dYnuAV> z^*Q8T>^No5E46`&F|N2l>GHBBjumB%$~Kkd*i(zU7j(#Nk+tFP-{fn_D^im)_T<7QTr)ryXvIR$2gREs3pa z^^ecahbX_W0OBrrne0tn2IWK$(H0$7{>66UE0v*gN2DSsd@`{46M?%R%EQE5!6Dq? z#(E!m@;u)>K6e6eMcNBLFev-OT?m60;6 z0v=Qs%`w>h_toZVwrZ#9@^w+Vr`k{*t21f$s@-e>yNzwFIzn6McT_WKAK?KVNoU|; zv2sOmi116eB{+a@)WGNR&hhMUzO?^aI;*%rL96UV>8n!h$rX}sC3pGzBs-*_Lt#|O zD7Ri54^{9~`mgqkuBZAqyGC12dyslAnS>y2m2+1inYTZ;YqmY3cV^|BdiiaN7TAZl zTKY8nX2?jtf%||L)r=yj#(+{45+lI5VLM(x^aZU&ThOX3r2EiYf&HQ2pRhK{XyEY; z7G85H{&3F0u}~p0D^PpGFOsL}BIcH=4tpLPO_#G8aMYX$b&MaHRk}RA(SY=v;id7E zalb)th&7Bg%+_0V3$?Sr)3K`NqxvVD!kej|uuI`D>$78l(*|I8=(qW#T%l{~{-FB}egixRf5aqhCLi5dCnu8jWKqjRQYTXL%8 z-zq9D?p-Q4mUuRMJ)Twmd*U~Fw6tDqD8H1)iD_IsSC_Bq|JU2bo9z4S-^GOqcSXJI zm(QYLycg(?PT}*32s{K8g2%>t>^W33n-L+@GO7VRQngtfqPefeHJdddnltQ7)obPm z&5?%)2J|op)D9A{k$gxk1)s=9Y$BTtE5lsPDRm=Y%=+Qn_Y=;Oq0ANf80b6{x})lz z>H<4oGe_4$-(G)KKg=-NFhV~{cR_ncb3r|xJ*!G(>NDr*Idld3HZ_K-O;OYv@*%mJ zECKaHEl4|Ef)sh2G)>F`wapE_Dc=ONJi+`PF3q3hJLt`Ezjj8KRVi^5aJfO*;hAze zlX*0|OI~=sFRy>W_~Lh^&r3g)I_w!`<;&LFKa_28K6D;*n4KfsmArGk{d{$JTAB+i z_J6SF#3d?-`N^zcmeW!6NSN~{Ku1y^lJ|B{S=3u*ggQ~vUwc$nVCZGinGP5ejm=Gi zjK}rwwBt1c)YaHS%mFG4_|h+liJ(N9f=?nYlMz%HwSrm=-go0b0fsaC!Php2ngP4O z)zFp;@j_5Rd;=9kFrJ2m03&@Qz86o2M;dgcn$$xQ#9h*AIT$MRg8^?cVmhoNJdZ1t za=q$GZF8wHC3blaIvrW{1<7;+7HE*k~yRN-XtIyJx=$}A!(ry@P>||VM zOff3P+b}z>!=svUyy1s_fPSHFiFU2#lG*{Tx>eay)iQ9LE@F7*1=AZwnwB0##ey2} zIaKr4DK+G);$&e1e~(M|8~vAkw|u$a!93Fc&_9n$0|)7oVoT71K2p9x?fnM&hW26! 
z(54Tt)zE_7h=asB@+379lwzl83$vIR&WvRqGSyY(Rdv7vI#MMtcIG1J$=cCzpaV=L zInbGG0C$6Opv+4G?OT)rI52n*Ye5h30C>`ir6!Qb`~b9Xb%YlDf1Jht#hdD}dU9N} ztG4@_x02uBvw52Wk21(#9-?1QbJMwn{{37E@VyY=PNBqYsw3^AIaMO~Og_@=)3(yp z(1qx}K$WEjs3RL_*K1|Sc9;*Is#?u$V7R+9J2i3YN8pWjP`!%nqM8X~=ojT9GYC)& z0q(IMYXjKbHk1Xp(O+~A(1*6@5~vaS%WtF{V5^)0g<2`UmtV^t;4kxE_-J9Iuo8Hf z4Me9{UQ)!#Qdij{cUQ8M)94&_4-Y5nQRSctSE$MWzr%d>d(8sy)jp^FrbY0$pijvmLZ&=`jC0GaB&Ot^}@PV^wL)8<G&r#i(xQH7~zYZEjhfjL?MT&HBI3cthu-IwY;=aJo~J*#|+xN7_u z{*JIyvdYz@2jVy2rJhtyD<{xud7R0D=r*+Ct93Fi0+cCqR)t$?DfBYB$m1iB~(9u0}iGH5_(tWdFm?{a6c z7G%I*6IO{HX*MW1*C8EfROexzLF1+XK0a6eDmld@(I)){mby)-%m;9UuYWqwlLsP#2%5xux#N9tLIuWZC1>0V%!-DB2ioJenpi2E`C3{Sq6B zqohH~dW1dzw|kfELm zcdtp{wBH+#r6CFjPna zwZwP8TxVf{L^kMTP2ih+5icVC15P6XO-p0o@nnFSXec%fJq7*HS#wG(|)&bOo zui&m{0>$BKaxv*3!|Agy?#`-iFiq)wkhS_y36}k$Mnpm#(0=|4=>0p?4AoLnr4r#Q z|A2eS{pL4{739N;6>wWS;MhURI5C)C>34f~LAu9cs8;zL4)-f=EFjGHAdz4@9nXAZ z_Nz97<7hZkS|@|E+61aSxg7tA&d8g@dt66O<|CynDNAmQ1(0p%m&{ky4Ym#oNztmU z>}vL*Dw#>4RdiKqCGi{$m2FZ}v5Mg5c5q$zzCu&6mQ+bTA*X`hejoG_@W?pS3DJlH z4RU3rjl4n%kqX2W;%}jj@QB~Y$MV3V=6Caf!VY1%7%bleoH!MWB!+>{AWnS({ElVy z#)05zIuHCsb5$C)5p@KA4ycL?UxA$?>JtUH4qr*CsbXpoD0|DR8f#*;-PKpX^L?W# zmYG9!C$|7Xc^0%pE6HNGFFwLHVok7{kdB=RJst<%g`s#3HXXEEMM#fg<;S90ijXV(#@q0lPB>vpfxU(aOtoZC_R*3|G$=Y2PjC-$uywc2eCxl4R@e4Q1(9}9^xuc zhE>DrVeMhZm5DvXQa~H%kw?mXL224q*@AMh!*~@Uj$lcJm_u|0-t844mdqvZLV`pc z`X8z-IiJ`Avtt@45W`O#%EmNX>qCE&F{my{ol~N7K-h`h>qza*3F9Sbkj=R%_GSP#4O7?=}hwk)7(5!EyJJRQvEK2%WPTs8Y4(B_&@9+xF2_}bo47;P zr^-MJibiC;o5K@mLyl%qWqGk85{LG9aDj*w@{vA_xVgJ*#1{4H4>lomen8DR0p ziJhRXo=yD5?*rRor?N-RmYPW!;yIC)R!fg1hg4IZCg;e{K+mYcF;MMJhm7i*L=WOE za3ueLnw+z<(ny z!EB1b$6{%))@@ZL0()VyYy|bXPcq3JjG!7+GIG@hF%Mg8{m(yfyyRF zkSB<9pvSBPte91B!fOuMCR1dW{69|D4`nI1OkBYBfj_|nqAuB@g9)MU4^Vi(r{L5 z0)7eI0STKzG$oD@*FZ5K;(x%+<1+pRJTx}r`|yMKSNsX++DGAwKHmUH$UNLey#2qH?gaj( zs{0@KIoDUDgwjA|CMq&ShB8KmieyfrWC(@o2}LM`C_{>(=t)sRB@`MB6&3O{pb`lU zN|E87bAIo0|Nqx{opbNqXOC;Iz4lt4z4y8}^D48!Qd;)adAsuk{7nqZ8%n5YwAPv5TUZL0(L&;+bdVSP*ABCePLkiqr@be^*Ag#>6=LE--J;}yoff!?w8SZ+?h4wkUntbSFbo=0TyWOT3$?pYY zr5qX9Ln!^5yc_fm+68S%*8)MPK34`!f<8PRj0>jOsdiIvpM7%umC#3RQ%&_h!Rx`t zQcaQ1t>Sj$A)rNYq4NJ!(s29DIwe8!hW~8(+&-kA@kTLW)19*(=hXZH@n*Q@B<~wFc}CU)X3xjbWHz5JH>u)L19j@Dq_%wZGz*(J7v4%~4UCAAe%~K14Brp8 zh2Qhf@E1Qe&xK>OqqBdz>H*v6sD7pR&iTX1%ja>yJSW}9D-B;mu(kzUhQ-_wfp_MxPE$M(|NF(LTq4X3d^FE41XcnU(k8}ANlh5*AA59BncBZ{+)&Bm_E(0P*TME&+Kj+V?G=f z4>u!r%$V>6fza;GE@yonCw3l{( z`LwHihVnJi-QL&5%05b?EN<*QZ|~w-`v>Qdes3o4Y1=Ewi|qGtJVIP_gz}Wnc5!jc^s)_4=!rjYB)t31EuXvvlKeBswp?$CG?ZG_2Pf|sALMJqS zQf?EHhwNw_h1#>sw~rcwv-m7p%qPm$WQ#qqOYPJAp3K;T@+Iwotj`NTc^)jX`eLV= zy+w9xNIprn>z#z3gW^F6dhBtdD$i$;;9R>NZ_wAB)$L;bhqkKy3}rl})ZR&Vdp^67 zTAh-v$t}5iC~p``joPQ(^GUIk3M0hnU`KI#ZL06g{<-?(3au#aD(814yI0HdxN)8P zE;QPZv_{T?(tBg}R zEBPY6BlAT5RrwmRT+Yw?hySA_@Av%I=c^tbh5*(H2GLSChi3!_k^#K$d>9nuhxy|C zjq-gT9`pZfFWdP&i+%7TKPL4v(}L1u#gVYL+4GrvuhFo6;MHXz|Gj)d=9S1hC-1&| zJB;O9jq{&_&%&2^npsTdp5STWR8l!8!Z$?OuyNQoINdJTO{DTv|K~G1%p(P6hSm5^ zvsM|d2}%Y7%|gq9J;|HANR$e`i!1QT63{(Mpi<$a9zQ61$*~XUV*BvGHeOE77WRDf zbM&2^>bvO@OXI_2UAy?Ic%`w`2|r9RQ@_N=#4Ys8P4@eaj{7G!#m_{=Xv6#Q@C(WB z_A0+gyBkNlP5zM5? z4tOK8)BB`RHjPzKk>>sf-T$`aa(vq?7#dv9FUJ(~-!WXCg)=k^9|^7^v)_lvj4}rA z2wKvEdXxYCf^TSm1Nb%>kLy1R7Kih~7lW?G+U}$ko;*8TVeVKMWX&N(jQTUk>TiM> z_WZv?j(x~u$!~0v=luixLuBJHd|DWwhV_ctS6oBYb}%j)C}lC_ul?yBI* zrcaJL{(n4Xkl8O5eA%cRRn_RwV1~ z?CxovD-kS=ALR#VjB^y_^Rq zv+dFJxUI3;hg=wxJOruU6yGYJ&hdjh=hU)-)!TmTx9l$7Nh{wj&*RXClK6B`QX=^t zFV5xJrc``={5VN&0iX8zTt#b~4Q~p6sdYd2qGK%B>5p@1;&PdE)pyTYxTGA}SWG&rZws&7Ph;eyTaIL+#}7qO!l? 
zsiT&>f6tDIra=v^k4i_UvW=nBRiobi|8_3_OK(LJSqDM%cXpt*HdjiN&GOInj6815 zF0eN6Z}w!ic65vO7SQ_Bo$aJzr68>4mdU{TQ_MM)vM-K8Ws(YDZhMFJ$-n$G)e~=B;SD9j|9Z z+r+AbezW=H$|IR(^V?N4s;upcAp;+DKUo^hz{ZOES@4*`s`!7EscS z(Ma)Y`1WshA4;C3t&^c))^OnQFkFl;iM;j`WJ#`I_t&pp}bGXI-?@~Hr0qVd3=!+`ujJ>{$C8s3waoGr z$*0mmA?r^kY071tVz`-p+1k04YuFpLNW~5$Xah2_yQ?-0WGp#3f^=TOr{bz$CVl@= z=-vcY*xh9EDB;ss0yol8$C37T_?$vNp2Mzsn8rR!XkYJJ1s%!o+T?CqI($#1_NGhT z>)Bhp0YMuvQom5G=#`K10d}X{$H|Q+PxeQDb!tP_mmte4@k>@o%^UC%R>?`5%Y$OB zs#;RrY+6~^x#lK{F;8jd2Un1y=W1Egpaw1F5^sy6Kqc43PCAv4&qZ>m#@gyHW)+BE z9dR#XQMFW`KICxc;5Pd1LYCQ5d!SSM>_e8oyAZ*T*g|jSY^v8;2CH%*+i2$dUAuz} zpC$VP=SQ6RCDTA88h6q9j|AUBH+R9Teh7}T7qaHd0_L&&tliWeI8)qWY=VlO#pKg~?tgyH z{w?TP&>8t7!6Ely_Cq%Ki%+h#!6#<4ceHUi{e2@%|0yj?eQ55H<5y0woDu%!`#0i# z!;9;77QuS<$R|#>Owz~meZHv27r3YBzj=NwHcv0g<+$2>{9n0kWkbD$h8vZ#)|~z$ z%i?g_O>Cor+IC!B3WQa|v&?db)a?{&;tST!cGO;FpR~**W<^;+{K<> zDf~0-IH08mg$Cgd+I=WE4VqFeyh#5vpmS8o*_oyESxH(;NgQ&VK2geyK#tGg@yId$IoCsD&wK?3H$vv^#um#`hoV&n2kxuK(0}8#Pk9 zIMrcS^XHdRe6>=Js9ApfdQ6`0%73F=(l(rlPt#f#2^0N$*!xtcDH#6k{Z4%Sg}Uy+ zZ;5w5X~*YMZV~#0Jik`M-`SXZaK!)g!$H4ys?pcprLZdN z$ms19bRmU$kVE~9=mEH91Wp)+dnZdV()ApfHDAAvW8L0w-1mdPG$RG?Ao=dF2GEu~ zxSGtloct&h6m*})7COOlDkG#Q1f-6)mw10UjHwZM*p}7S08iaPlHEr7_4MsNzlY<+ z`>c@kBVVs4cbbxlH@lii)tdafmIQ9+O=lrp;5I$+Y3Ce|567duwdWo=4O6e7+Ba5S z<5}mAL6Z2|v1i zasPm}hj8&xax6~|*Vxb$m%h1tI?Lr+rCg^MtN4@fpkfO$0LF1?}?kp6Zz`Mt_Za@Cpf-1X=pI?06uf03*`F0^YjF*GP=D%oD zjBjIr;{mN!4xi8KXhZ>jU zCAGZuz-wT`oip7tz2VPyWIFhDowJEIShs27v!UOeG96*fErm4CwDqQ4<|e7G5zYkm2cNE$_bDMldSU>EN8Ju~*D#KvjsC!7$GE-OT8#1k|qg^ZRJxc4vPijA&RPWEs z;7xXt&pyJY^A$T%j<>Pe`-|B$(^Y6QHEklaGc34_HdN1CqE(s9IV|ry*!F3j$5?VJ z_$l29=}M`yGF%KNn`)2Q98`M@_BA7X16IC}$H{Rp@KM%Y#^Sum==-EuIh~`YiT@-U zeV#Z^>Vp}6KPG0{Fa7jg7yOpmqJ8w=?RtzCNxd{qTqlskPScFWt3z<_AlHAqouK4l z;k}+yb3M2ydS_A%j7*S8gcbJoG?%|<=M%F)=E6(`G^*?>=dOx=H8Yp_tj-s1&CC_< z#`>_4-&HfUeKzA)@@DFByUOr^T{2Tdxkqv+NkQK;nSXFz zQKjZnqhg*>crg4W{MGYcUMTI9#45TdV{omG~PU6O6VZO{?+PEDLZ$^tvsPdr| ziVfE3-qCXl&|-NoWBBe`B3g6?;* z^Xroc)giQ(kXjXKTBVHBi%7JyNQASD+k9kOV%{odoR>9<^M}XC$o=N6KS_Wi=JSGN z#wouzV<$DfH5YCtAJSR>OS9HTbc9W2z-4Bn|H4lf(M6u4AI&C()9Z0k{C<*oteI!A zcpsR%UUjE4*wb{WNu>BpTFMYw%Fvvy^_1rfQhb)Q<4E=y=Hlhl`FczLAG1O?b4n-AyUY;-g}0#d*KpS{yVZf@T}ETBW;QDdeX4AB>L^ZM zbI=Xusg~}#t|mc4u~XFVV*i3u#a%oo=UIeCovs9w@H+$3NYSlIbn(;%*}(?c0`c7B z-s0Nfb1Mw^Q}<3djTh~ZpmPXYq}OSm_5|f+48OsskW@BT24D($#SX#?sICJlsL<(Yn~3U{qyP3!a|a#B9aidzRg{uvsxoelSr zRfY9d#9oIouYrUtW#O%e--|zhu)Gdqo()fa&6>kBh}*dM2{_0U9)_QFwtgB6WmG)D zHHCHfro7fdSJq1RcDzz<-$9rEgejkboa}}4zRE-Kbg15Qyi+`{jI^{R@?Pm$Bjp;U zuI4#-kN7(wML+QQyOmGjjq0*NotLW5GH3P|iS;22c>{!fvox>zwHvNL)@Dzy<$Q`Z4W%g}2VJYWz6s^=a1v?Ow!&eT>iKUheT)^t$>iSI6P(-GR{bq1rdd zDsFpN{MB(ySY#{KaG%`U`{4DDia8Q;v{wCJk$SmOw`#?5zu(r5H1tL7eMGGLVD&@A zAEg$fcuel^Ih}nySBy7MZV{?HY6ajSCHJ!~Fwd1@oOe3)-kJaKyVYcv@ELkwl6!1C zRnI&j*30sk?p<%O`nzt0n)OiAR&iT74H0v4uGb!o2U}0*A+$@}O1xW;@|pbno(Yq!BBXnKvsgF99pV<~aHacZwdo_IZd^gSg0O_&qEk_U zIIzlIAU+x$w%ad1)U2FU{*wGH7LD^@3xBvGe%lNBoi9FVkHy*XSuluycr`i6V6ZtZ zx;U<-HkYbXD_)~pz~j36cC$5(wB~93Zq*0F;SrQ-ne#6k8s)TeliM;j&B#% z9S5XW7bT1kZva}RyBIp4_Z>>3~ zHsOQZH)(5WlrDfGadaM!*_GllXqVzMDLsFMju#f%C~kr>#nJO}`wHr#P1?T$gkLP| zY;-K_eH}Q?HG1s=AxHS!-W%?rgme%Tv&*BTa8!QBiB zS~sc|T?duwDMmB?|0_r3Ab@50zHbhf>jJl{FHI+Y@B72z9`S4no2&*2tLaT^Xj!+Y zJp^lz61%vEsKsxPt`{MN6O`RuzT?%kP4t8}J0Nj$wPb&^0*dz;8s>+|ZiUvJgl<)X zZ`H)7mGEgvesRjhjqySQ{Ctfywkqy4v^G9#iSs+5>xHnf5;&uuYpDLdTmSZwB28Nk z=hnp$LvhAroYURi6Sq8Rj88WXpLfkLG9M!odrI{H9-U9pyk?ZoCN-xSud|KcIb`Zf z#{H&v2br9vkT+wR_M1@c?6v^cLVg{+GdfbbuP52@2ptDNUq)yc+x`JR`Y;BEdq z{&Wi4lT}izb!zvMc>74i&AF1lR9EC7aEszgf@t!YaU*X4VhQw-e2%Q>=7;mh2b4 
zLcA?j1~zz`;h{fWGpt4vC}GyWh~{t|>pHb#o3ml-+e=WD7IE0xP$BQjiC2f_a5+t( z6ifMHtBnn47kAp*)sxmTiAFM>Hk4{CgXu`CVGJ+PiymN0x2LO&W#KM>%Z#884Tsr0 z37>j{<+_Gmm0I(uK9wSeLs;*<={WV&DXgZYr6o<|+kKY&H?xILvIqBwJ4f2-P=~Krg$Ld5NVftm_C5P~ zmw11Az6Uwk%!>XgXE~m-PhgpLq$ofaI@$Cf4EMACo2XG=8dozK+GuruDu+|`Lz^qu z_@mUlpZfg=Hc*Yeb(45?QEEKRuN%~%ieI(i1Vi8nkEyLyFfp!1_s+D&v2?xrXoKyg zt%VjhOEpt_7Sk2S!vj(jFd}(UFp>qr6M;d47WO-L$1zecQ=G>ok^Si z9i}jceRGw!X^isxrtAo7vmtidFWod#hZV4tE!3SJ{%dd^&HE_SumhlQftAY&#zf(y z9DIDNb@C}vnW1iuYAk)nv%2FX0tWc`889V)=-2_Y`7O$ zB}-WpeeM6Kg%($Owm^$|(13Gb0jI+l&Sqa!hwxnlaVYD11)n8^7evA9U?sP*@NZGV z?d+0A*ob{ltg3kDu{p}X7Oq5%6h%HALh-A9$p&Ym@zET9c8U+8tK?CW9do6t0tBL} zx>b^QYc(kr?qZK@L;17BsKBmCQQr+{^Q*Q8LP`lcp4t8k+J8eoQgfRjjJ5 zIaFpV-h2__xJK`OYovW-WX~cAy5!`=OjgojHJ=RiYiU%pCtK!7u}NtQrF+FVoKA|& zFfNwLVX@RN)cF$1yUqyuUcM>T zo$fi?X3V6x`42|UulVd&J#`MMr#pyNDQk7^?amw~yPj3}G>X4vtfffuH)!#V9MibV za{2ru=hwx0%E+CVd!`8ey|BWnuA1b?4LKQiJNa_0@!kPC_#{3btHt+gePG5t-_tJp~^(EyuUsf&!!!t%`@EdjP2p1;=Mwi5bI+#UMzHo(%O({1JvPF z`HxhG`lxf6d8V3~DP1+HV!bEbZxkiPNyYl!H6Uxx_3thDSgghDf&14khl6ZS>)`BB zE>EqP)v&U3sm}uT>YWRt9jNSy@=rOf78%?^eePD92jQ|!lvU0AR)y_c8!44sHe2=&{tw*Uq%;5k literal 0 HcmV?d00001 diff --git a/docs/source/_static/kaldi-align/at.wav b/docs/source/_static/kaldi-align/at.wav new file mode 100644 index 0000000000000000000000000000000000000000..caad1178c0f209b7671192d3f79f03044fd22ebb GIT binary patch literal 2620 zcmWNSc~nz(7RTSdiiciGqSC;#lie#aeJ&uu@T5nIqzI)XIo$wJz8y zE^UP(E_EA6Mg^BqRB!Y|#vmW}de=9pCOXyZ0h z6@$ZLUQdmgu`6**680H8)fsnn%C^+g869c!rk{*nJ7WLnkyK4;KI*Iqd)8ZDRP(R9!RIp$75oMry?Utf;ES98?E0qR z+sojN<(FT)KGr_7z9s-n)pl7 z&M%S{ZA?y?R5Iz0q+PR5&fPxKI&I4&jd9`#Wr$HZfu5=`(hQu(P4a|UH@&ZITH8AJ zeM9?!wppf@j-l^PG}gbo-;((*>804X{?+38iu;C=)U#6YW%BWhIY%(h!eVUJAZZ6y3_qL&^p~k-Dm0;O@iu`d_A;B)*jmT(c>sA zA}PWW-l1t{{!le&CxnjrsBv&!=t%MzBEdnRQM@B8LF%Dzr3|i(OY*IfVx*4efXmiCtUjp@&`tHtuc_djX)_nS=(Ef0p= zUR7hPqN-liNUzb(gx<3565D0F*1FnJ$Q6>MOtUh{P#1n{@HBm?W~82G`vrTL^;FN=p5r}td*dvTmKEI-yFcl=*J15^(VOYm zXn*9)?I&G?bD!PhiS{1m-68~7;8L+laPVWKYnYDON$H7~@K(%5rV|C&4J3~oPtPXI zcm&o+tRo`P7$StisD%_wen(d+M#?WKH>(WlH1%qgS7X*>Y5u34rYTqVF=oYeW*@VO zK8M#JTe0hC0vN(KdM|o@=wIOgYz^HZ=GG4h9~wW1CRIyjOH6A`+vb)@E$7o~1y`=e%7*%%`9J3#b3Au~ z8_iqzMDaH<2+W78u|8xKlnke%PqE#Yk%*^fD_i9T#TN>n0%3k-5(e>xLqVn>o52?J zB)BJdMp#4mrNO7emV^WcM;pGYBSVj?0I6{?TB=#y!qUp!dM|!h?LZ@8J|cEr?1KmQ^U!X zlu5?Y&**;H5ZP`jovgxR@FLU$|BTE)e?s5F*TF_G4xA5UgR#&!xB$8f=D_*L5af6G zKKcuFgs!27(edA9N7D z3}yjQ;(I<=oWj57y~0P5PWm9VgV&%d$V%)G+KvkN2RfU{U?7H3LIJI0X=1f+w5PS> zwWWi;*4P<6vtKn@eNK6p@?ptXC^8oy1T~jEQ17biPjQymZuH!=3|LH^q3=^V!p#r6 zPIbnb$99%=Z|Q03`PN$R;QISqZ@h*6Jwh+PM!>X#da8&2t-YFBE=R54M}tI#N_sXC$w zFT#w77ibfo3oHE(yf-}C-9Pk~+H-BYY(pJgwz0O|R)y7NeQYOP#hy$b;#UisM3qzl zT!iJgM5IxGtWp*%vr-WKr_4pAk|jhk{u;>xmr7~E0{;*;%9}9&crLlNyULy6jy1Mr zwuO!`=MZO$GsCsgo#I*HS>WBqnfP^55?Bu9AhqamVzMkwu2tHVdsP{j16-%{TLZ{NUOZ(yJi3%$n#dR7q|+J z_ILYF2-BrFkbz5(ZgdfD!OsvC#2B)edLpZ!SvphROmCxa$o9w<%M0mP#R_?y>{vi| z`2K@C(JxOo9g47}%b z;G_Xs5AB8KK`>kbSHW8N7&I0B6v2_BFb5~1B02~?hipSG2da5k6@Gy@hc{q{v62AI zHMku&Te@N*PE2_n2VS-p9{v@S~ zqr`_|ob*a+6jw{F(hfigT#(`dT}Q-zAy0fJR)|f4I`HR}{3k*zZ{W}RxBH|0>70>k zVmmk}ke{Vf${A#fB(1?Pc#193S(<2-a5YK7hh_Q+^B4PFZ`fg=JGlt>Kn7`cb+LRh2_iAO2| wJK-@}i>lDoXe{7oXUjot@p;UUrx6MiBwQ!fvtq)z5zI?(P=F?s)C)z(xTD6-2sgfn~dM zcIMl2|Cuv!&U2r6?(f#`jvYE^(7J8_Fsbjvf%BKH3FH9)0D5FcZ_mY60Du40#Adtz}Mj-xEJj{jiODVXVZg_TSz~~H^ykxh%RJSFjKH> zYz(V}HHdwIZDlXRJMfvDi=1{2$w}sp;V$6*%bmoQA>88XQYI(>^yr4?z{X?U7O z^<(v5wN<585h{t=OU?D{QM`J+`k7j&Ija$9XKCxSGjvS-3;imC$}raW(%8fF!{l$C zZN6p3Eo&@)EyJt|>r>leJ7xdwc;-CmT13RU3*5WNxzr#40h_>^&=GhKZLCLE68Xb8 
zhej|TGsk-L^kQ{qt!5o(eP{h(vDl$(Z#ILiWVNtBb{96l{>0knc}`&_k2GT^vDsJ@ z7K=rDWCr#Q`+|KIjS!8)AyzZ{2ImNm4d+`skCZn4X z965}HA$O6R$UM3bo&-F0C%FYVQ4le57pnwq&DAIe?InZc>!Ft}-) zJbV&97dPR3@E`1ftXL*Y{|3$_ClE&GUl&aDc6B(4?Fly0BC%XB_cfu`8SbLO}Y``;6bbfC)jFA@uO45r-mEqpn7;jOiQe727Z7Q`GT@ePKgF zh6HB@T?<44y!?E;w@4O?*9lYjUpcqf(^yd$oq3OuM?XXRAAAF>qo$K<-Cv2-uDuSU z^{%2UrS>;=rE zV%;>t$L)0Pw|vnywa;j|(loMpYI9cev=&M0)V8Z0%M-V?v@PrSq%oK)T`WjJn~Pj$ z%w;ZTyZKepAO4?$mxZw-HbwT2{vEd`er4jjByRHi#KlSJiLqU?*MbcVqt9;yyJS zTW_gjjfJKp^M3O_hn)J(1cWN-CLe!qFRu)7o_K&H+WU_8a_`^L)uIAUEK?1ipzgVE zIFlU{t*Z<*s=KYCruhvO4f~tuExqJNT5dF}8qwy=)?La6y4B`8_TEG{a3S)5b%uXX za@^k;F{#VLRCdqTJ*7Qw_I}>4)1dJ~S;HcS%o}j0_o)GXIAtpv6?3>xTWEb&Bo1eVyZ=Bi`xjT;@O>*B$SisK+`rKZmD=!a`1On0r%UQ4+FD+&o93h6wE0Yo89C6^$)b%^MzAu z0-N*O#wy<{0Zp>@ps9km&rA?Km!0xv3-|Hv@d)7|@0$TX0;~NlOIPyRFb*A{S>Zdh z-tYkGiw)GS>=@A+*_zw7v!lDJM6*uor72JeHAyDWrGkviha4wAQ(P)r5YQYxJT5O` zZtCS;{61J3)N4bphTb3hRP{+uv-DWlEu+iy=o_KS0zdlpm1T%%yi!aL&m%KjMUDyf zBet*BI!lOYqj9+LnkmS#*tXryb7i_LjE=$xm4=(zY5ML)nbqnzO>L%Dy1%(o$b6#Bod_Q zjPK|?)&Wi_kHedZe`X7~;rv=ri&uxVw@-lYMDJVP<=&rsUInWBwSFx@1iP>3=zzhOcp$3+HhYEgT9s-%-+m|5hG(KfM(H$$sgfoQsOI2?}6cegnIp|?RPt_+uuNF^pYPFQ{zpP5_^r}dZ? zP;JoO(}~-rHXrC%*>STqxAA+6M)9$BM}@I!Uqf|^yx~Ced__O`&DxCm#zx=fwTi3` zO)FDn)TgT|TK+Uet73GoR7cf2Of!sO8m{_=@|NzErdB!Gq$CXHCZp8+Qy-wy>z24y z0#cX8JqlP&taUa}$H8N8GxGq~gndBxtgCo6qc?J#|JLU`Cy}>9c24wze}ygNafIO# zqu`E!DXsJt^VrO6p*O!4b0cR(=X@@TwOBgM#k?Yn6HK8E0=r7@%Z7pu+C%}JzuEl_ ze90@t=LtR__vyjx{>=Tf2hIg9Cif#F9lAvy>nbvV?r=k>HUkWGys}7$1H>YQqT`WV zqBx;VQ9m%Jn?vP3?eU5q4a3Vf+0L5%8@H?KsI>~8<`3;Bl}^nCW2U9ZQs`P}v}yYo zI$Ht^zV>AM9O5?=O)S;r64T9xt+BvddmV`|2f8{DQNS^BlzI|5n>Pu0Xi$Pb@DZd| zC!*aHasVIuTIh}NG{*=ifRE62+IYrjXe0E_@1n5Ivd$IZUBo`@4u&{Fz$X$JX@h_^ z8BZ{g&ZKs6Lj@XSI|@kmQ8nmfiN<{kd@O#!n~e{pO?FhHX9b7R67a6m%x3#ZT%R?7 zPjA5o?FidFycaFTB(}=A3uz-PnCZHs$jPy1SaA(P9|xT2h^rAJT+~=Lwb=^}vs`G= zv==&a46kdi+IJb+o1R+;^%YeN8BZ&!`mcHz@)ej^x7+M;lu@UwznZpcX7jzZRSi$5 zIKfV8m?2U53hsgh7)N*X3Zxq!D<_CrF}IdW24a(In{6f0DTto7*cNP8fhj;W`xoP~ zd^&rh%%7gA0EAeSkG9bCk{2$0!nSJSm*;{L z==C3Zk}ZV&oW0Nq%(;xOE*2a1 znM3;}5VDS`26&xi^=19*|6F_z2p4}x^W4$QNwV#d7_|>OF@mCS(+=@#&REt1q{7h- zCI{Ri7ICkc*P_0@w_MZer9N}KX19v%pV%)!0dqrph5fky4fv>0?doIcio~HT%NwA( zeG#?G_612-%1mcFCAdd7^Bmp1Mle<>XWF(nN%3B7Ym34ZW0fFIb#2{!=YH4X#@+TZ zO^|#v_(NH#o~q2DdRpJ9`Y6xRAr8;7qIEk^?AVV;0k=kH_Z4(=rJ7GUASR8WH>I1r zgi=O#W1#DQ@XGd?UTwx-ire}!A#AEPY(@+QwVb69NTq*cz zYNB=v?i%?`m(44A*a0A`%T?l-DrU{*kS(sg1=G35n6K&Zg(k7 zm&h#qqjRnMt!1i2LA9wAXoq8tNr9G#7F*w8FWIBH9~GIj8KLK)iv&lw-S0oZE{bqL z*b#mTP$->egQZV>x0Tigk6<)$(#Zr}q~8(FQrj)c-O(ydL#i&UY03mH7+v zs&M{HpaD#Wy`T!r1-Kpki8$Fu;E`G*UV~Q}ys=7tTOAfECF~^=BDaBSl`F`T^hm1^ z{8?{;0%>_Qb2L`o0ey4F1Y^A;(XbaD)o|04&RTDML40aC2mVj6Lbs#tyda8eR-7`q zMeL5d;uk8b>!Vk^BZn%$_L(640Ey22#2!q3)b|K6>b4t3pxsV8=f%( zh5xk30tXsj+Fx+;X{hB3BT0~s!H$p6?8sqcl%Y{(!8YUXY)SZd-$44_rsYY4orR`L zF(awY*4NSqnZrSVg!5{MpHo3w?;c7{4Xy#f=9iLdK@#m@@C4GImxd74KH%QCY9(Su zC9MLvjw9UfTnnTKzPKZig@H%w{{u&Ld8u+L2I3O_TI)%3XEM{T%yhovis(Ca68x?# zgQKBPcZc&S^{C~Eum@e)lB-FMw}r%QyIg!<(ErN-5NWHIlcQ}S89#}qTbYXP*jI%E>K2AD!_VSj;2&|l^ykYEzoLY^XyAq%h7h+8`|Jxv3o&YKrXGCDDm&4 zOL5`8$;5uA9DT(4*@in-2=`mlDId4X28x!_R@Glc|3xl38R{;gD@4Aw1;&UZ%U)=k z5)pvVMcZHF2Ht9C4`(mabfAFQpLnFYEHD6cM=wU4s|kW#SwtaRg;lDb!QXhRt>dY6 z%)4qt&=ZW%o{-*hhr9dJGid8HBl)|aQ${*d$odDf*s5(ia2@;%KIm9Tp9XV;1Axa) zR5XqHhGat#w8;?1>lE!G*~%}r9uj^g0p|=s0{sMM1?>{~UGRjhGP-F?MW^5#a+hG1 zy;}5vb^-cfe*iDSi{MVQ*F+$6M%vGkMjWH^MP12OV`t%H*KBwt9^-Kw)!Tg&YjqU5dgD6sy|EkEB08rSi_`3__L=B(P-Z9~ma%Ff$^7&`LW{FR@kJ+~H z22naYihd?`x+7?pY5lC@0T|xn%7tD!qQOGJ1YhW3 zWPn>}JsDDhk9Kpm;4;HQ_Gi}$*#f(_@nhHl&0F_#jszYCj6vVRbD_uZSAd0PS~{S5 
z<`+Hb>vpltr-KU}_l=q2{?IBzH;K+Q-uyAc*fJ7`uTw}1OyuK~ti zjBns&mxAof)Oj$e09pn-CYne;<|Jz&Q|RhUzpEQgz=+j#!KtL(Kq=#VLMS^`H;&$) z+gt4o-$NI=7~~LWDimq$!Fgt$&Nv2d;|Lvzt`ok4+D~x9X|Y5&4=}3*lc@JZE9(|8 z6F6-!h)-B=3KZ@nQ@mFg=`wd>*xl_OL|DbXWSNEnz&uBSkJP@~@|~Z+o@>|#-DFD` zXFTG8IYf;ehm9(@vl}ITVP1aqB!dmeI{-LFznF?Km{Rc%4&iZ-un(AowvP za2SBq2x=ce8v`!~*SME6gItkLvFMb0u=6GsVBHKY5m+@P0wXoh(Oc5P!nFqp734+r z5PUIbv^Je$Np$KU;1W9DJkk-t>t$j>34p=816cws*7v1d1N!S*@EqWiB?sVei>GnL}vv`NQ3aFy5 zHtpi3v89F-4hKJITIR^~wi^3T$3>)j88Cx3!BNV2hO6Ygks6NDe9xW`4rkA|~p(}{N;?^{es=qAtZEOpmHqrog-Fs%=i3{}ECXoDfC z`x$bK%A!AD_`0*9aj4ndMBIeFLG=^@(&5d}22km~3Y?@(bU9tm$u^n-vJ$7kM^Guy zL7Ya0(!Npah;7g_1b4L)B6u2P08+^YiVrqW+0=GwJvjw9=N=DCg7=bQ_Xkh)A91lE z6_rV>B)_|9Bmur4Szt5hBo9+*#D8En8bDkE9s&=^jqpCOH&qYZ1!hr4pq}70@)A@- z{e;J2EyzkxL(T&I={)!VrGutJ0pv#THnpqe<=l33f-jr1jDJF zv@+@ip&(m`5iYBvuj{vCfNPG;YVo(sF|RUxFbRz%dLQjtUA`_xGfB0#W4vmF`g_MV zO|^cdX02hm?Y3=HBY`yqu= z(_9;gIqoH{$)51k>ZG~-TxIrbhnF3)8*SH&Ps~#-HHO>z-yN>D{w+S9dgGb~dab+$ zu3cZVzB04;ZNZuE%krXgYqD2m!v(|2m$l0c$IZuG6JUayD?TB8;YCQ1fbn6OA&EXq zgmeK9TMPYl=IGC=leFV?S9GJzAvV&5&_5teXf!L9UnUys`^K*zpj(6_YEf82a8#hg z`?h!-PtUr_?8y{kmFRD@8`_m&XDmY1v?H`;`at*>{T0%kF$`&dg|t!d44~CH*nY#& zW^UAO(p4)J@^|ucP2a07l_eE{Kizrg$6l}79~g=bxr3Hg@Ve~NWKfd8f ztp1$7+u|D{W`w?q2n}8?%|VCKW;scHT65E{rC$zZeaPDWrtnkX4{h194m0tA1xT}e z8bj|#*`j`jal`jUtP8h?-3Xj3eT6TC7F#h@WK()gP1U}-t}V;Wjf__5sxY+6;+`>S zZ~OG@y)vbx^VOJ*;T?giy@v^8%pzJ(^0>XTy~cLh{s;KQ%He(F*#v8ZUc6U~OJuBl zi~-fYRh?9(sCOB8_OA9ybF86Bxw3hC`RKg9&zIahbg}HzkrSbp^S+LQr(TL0ClnjcgJ6c7ChJWIc}`Q(wqLHiaSJNsg> zx>Pnlx@Y{40kJcGENC4|be$aZG$7MAfwj%@t1Y2^NBNY(tehX8hP;tv-ObOe3pAdh zeUt^pmk+u=z9IcmKd`K=mZItRpN3J}qxxQ{gO-z-xrlo1IJh@}NxhGvJ zIvp&C<;VVsj)~kEe9}8ekj~nR1~MP85ArHR4WbOuQjt^a7V5a4FgBWu|KYvh%wis6 zbYs#vGuitY1L+aW7~Ia+2%n2Y;u^0cuV~S~+(4#`mQQePEygSAUM>Bq+(i%b65r%M zd~<1iCinc-#|4#}nFVpG{*5CZ{u402aeDla+Ne>2)6f(bs_ostuNe3@q)1ZCD`x%f zlK)Q;R@bIcftYY<6xg%d(EB4l59X#$4PPf%Li|z%HSMigRK=;m8>p6pii2u{rd(U9 zH(1F)z9XYW6lGt!0KhTABBkm}T8RrA_Qp(mf+4Q+5qK zMPA8lRT#cs?&vlPmCH2{L z61haP&KzrvwZ~h_Es*WDL*-san}TnaOb#-Hc?I%CLy-4&O6A`K*8&a4H5ZW2K{r!U z26?CZjPx1Ukt__{i*zw8RxIn#t3GJwn>M?LGw<^@@Z-2No=$LE9Oe_^zc>I4$ng8% zbsJv*rkcNWXj^hxCbs7&PikV!qnsvBq3XiZ4gey0OxP1J?HKnKC7c@1+9ura=8^{ZsuVW4-MfEa%RW zY?0oPG|0UCI{PR4^^(mOhOrMIXMim0N3Ed!b;}ZYf#QlL+4#VEiTF%ryOx>{D7~sz z7TnE_d)@seb`;!^e+I8yC`#^nvy&t7_SlmPe#~l0A^2wEq_fDhzHL|O?t)o42Y%(2 zy{q;uN6L!id}0Pq;I}<;OVaZ`0sXJ1jtL9kx49BJELEopZ{=GG&XhDXRJ8l+=h`|G zDpC#Rpfuhrak_M<&-8#*{yRlmpw9Y+mSeS?N@4B41~r=Cx2yBzwBn(`Ln3-E@5ByF z;y!h^nIp}M?eAS*U2A}sSd8$SXuDv$V7>5=_^!~B&wneq4KBE}P$3H6`Z@=^=gMOsTgN4e!0TIbP6X`wvnc$iqF?D^y zE0)_8<$w%_n~VNResaIx{tZ=iYS>dhsR`BBA&EZ!M6d4>l{~$BeR58OftTT$qj*-) z`g?Xht3Xv8TDPY?On1-n*nJ<|2zge6H+enqKJVM&d(Q`wz|8*EnQb2$4mWI4tZ;tD zBZ6DvNA#XDID63NUi&&R1NV!zptIcJ)~mMZZZWkUSPqx7Y~ronJH3y}%-&OcugeaJ z{<4~(JNDoD6t!3>Q~ggfM^|gkbL9hu25F~weGX}7P2~^ zJ%Q@U?#=DVNvIE@a}V2JwT-M?S)B5Bep#=Y(M@ODgEVnQxjl)j0oEf?ygj~m!V)4i zp>_Trys}WO^-5c4{n^Ixx={= ztahCN&NKH54oEioX#J1)Vcus&>sZ~X;ih@&?kb0BiQ&Gv!93FbA92W?>zZM+X+Ae> z`m^A>>eI2O-*0Tb7<&5J1@OZM_qQ%N{qLsqPCU?O;K;T8&qa^qWqHzX4KcyYkbf-} ze;V-aLyqXrno4C;CyR?)8)}YUowYo%tq0z+ZC;Q3-UMw8I_Ec7ybKSfKPC#zBlKD=Y8+&~ zZ@8u1r^YmOx}BQ14xi>7mFEf`ecJ!L{6_1AXBWG?xo;d4*MC?^zxkalo#!NHcP$TT z$In7)s)3e5yWmJ~U085E>)v~2epMx_IbMH>0S8@7tr?ire`Qkdpkw?XhLlRPy-*f5 zPijt4bXJx%+bRbZzyCw8y`dUJ{ANxUeG7ONZVzt@IwkIdJ)-`#Nvw8v4f6qKGrymp zi^wdQB(wVt3_2OKG1wKH7J50bL<}+W-5VU;>}LBRcN)D9CtIlXs+N>^wTbTY6F3oA z2BHH&BFmbgOK-bV_o^(kY*N+s+J`mq<)Xr-+$HZXJ{oZiy;yd0*v|y~X+L~Q;}A44 zE^23}kB<+R1(L*Q_X}VHxy3Z1?o;;c=W}1Gay#mZjS9wlpZ;-6diwX>nR+wojHDA5 
zPFq3?rvD=6Sk9>@w`DdpS1&FZ^gHa&s5-GO6rRD`CIx-ly=$a)k)3D8{n=G4OyH2c z^!X{QWPXIMfp?Ju_#W{Ye^ppQ^u_3%p?=bX%qa3W@d4;YTg!OHoX8m~oFT3ejKja6 z8R$IrGl5=Y;%~)1!)6y#-`cXM5-I|IcFBM6^GiPQdFqp|m;0aCb9&m19zT2$|GruO z#Eft58QW!YEEH5H9E@#aUPD`GPo48s%_WDP&%c^>)0U+x8=*Bajs_MbZ0dEU_x)}M zJ#pFFKvqCaK)x>|FhGIEvn_+FSbv8X`Tc3FcqU(E$Z{(g26h&nf}Lgh<6i{~*+stz zL0iI?M0F3HAiazQQX=POw}Z~(Nn|m;YkZ45ypIT>kmaU3&<>ur^oMMwq)ZTlKS7=F zN3aIE#hB#j?~RwrWHExh=uooCdP*1B+EUr?&y)h6ubdaNuA4Fn_oVMTb;k8^F0~@X zGGbVFMcAsyUR@t2Fk>f#zYY1%pWtqHa+@Y(kGR^H`TWYR5584psw7t{V>Lg;J3erK z2pILb%bgyqLEPbE`gV@eV`}}>2CDRV@vP#&(p`0-%Deg{=AGtjbG)5J9Y_LKMuaoLk%!_G*@qyHtEc+*PoNQ-CM) zuJTK{;jCe3IkF#~4E2Fa!Cyczcn#VH(unTH-!1mPci#iI2A&$8k$5uX`qu1N`=hW) zajpDt$8chi$Pr;rLi_C-x~p&R$iK`{TF&2$m*P8q&!&9;QtfI=Yv^2KsU9!yWbTWy z{oP&m_kA<`+UWm=eNKHJGM#6j>tG!mPI?;(N? z(2FoE-k8iv-I?guIVUXCw~1fDxIyV$4#!K|9M@1V2hvir$ipCy*~~pC{47qCmdOyQ zg;&6QKre!FfP-)(l8W4=kAzM+Cm4q*&GlFRO#Lc*d-YM~E%n{_w-wcsS^CI7et*aa zqYw`E502ya*f{X>fQwyG=|ArH4tM#PoRGX#}+Vg5g70t_?m3OX`uBrxNz2sqg z<9fum#|Osd1>wG%{r821#mJ-hzOU)HalZU-T~qCW=BJ94nrQO^r_((b3`J5n{epX^ zBn@pDK5bBH&+DDrg5Syt#92a6_y?ar8|6A;T+^|(by!EM*$iwzH?be_e1zMDhjAed zC9eTA+A+p4>>Ij*HXXR&THwicAl37VCE3|eT(_^^RNejcYGr8_H7D@CcR#1REy-pR zcj=ta@7-uVXeyAi52W+@oEt1 z#pk?6dFKW@Ivaa5^(jvM5Y~aET85}|ls9zuteu>`j%lVOO^f2A>X+G_Ru^aIWzmT~A%c z8$|eqe^{;>bga4?_t(zEyn-c5a&nW z!h~hr6H*)E{|o*q+Q3n8-t*dpBSbN{j2xz$CeLrWqS$LMMQa7UBqt!8^F2< z89y2W4X0EG8UlWof6zWFeC~MnHZQDtzqv0&x6W-kSF_a8$Ln?UKS@2i-%h?Aoi6id zx3Ek3+r;a=8Q$UC!M6SNzltsw6;_OupHbIoE^FeI`{e&Mk5pFLPM{WvJD3~ZDLBOY zHFp5}si4VM91H|*k#*s%p!1xNvDMJeeVirW1+l#0%hX^n9-AdT5U?SF8GA5xUG)9% zPJt!f4}78n^MlkrtAr=9!SH9Yk$6u`C&FA}>jPbV$IR9@@~ZY`O}74{j;9_VKUsDl zCopU3>%5O|zqc1FYBw5a_MVEvfB*bR(sINb;}VlzC4@%X1JXp-@glxmO8M3XhpN7UBTm9z~vNQ~pvt6HAccCIls=BST3`=Ff!gS--j{kYAn73|5P!hq1o^e9uv z44Il&%O1hzbH)oW*(up>UMB6L+nyEtnx5FI(>`@qQ|P zz(2!n=c3%jtk+05{EZsy=w=$B&(p2ZEL63%ZEX0j;(lT6w}kgl7XRbvue);(|7O-E zs&%S>>e}L&t=AbPp~bNuqf^7~`c`=9#eF2de9i{;4$}Ge6y1iL<`p`-F3)_@w$<9M zi)<^dvsDIHW!JA!#Op7b4p~ASqudDW>2dI%`Ba6}higKOfjj*k`1bVc9`G?_apd8! zD(@&x4s$o&BH)NIVJ64JH(?skIINfzi96VP*xfv<_5!ZCo)f_kpHYHHp;hEcH;+hm zbhQ`S8|(`$g{qqRsItgkSHGsbSH0cw;p#_G_RK#|8m0B4e{*stR336<;cbE#F0xi;Mv4VEwnY%9yiTEtNbo=R`{<%ML&J?*+Z?pk29Y_@+@zs>Gfa}bmRHVN2W=y& zT=Sfj_6A~(hn*S))w&s^n~>S(Sob*#sdj3ydnloB%&{&ZL7xr22e&)}8RBS)C!y(!kN04JD1$UE9|+9%otcsFz(Tt>YmcDkxuLGBhZ4jcpDg{z_Q z&{mj%q@%N$2sW10g?$RMF@7Vv7}FViMi!%haT*Ds^@WarGDrp90KBLJ?oLFC<3Gzq z(|Xew({`iWkZxG1o1nhcp;GvEEbpKz!&U#Nnw4@Tu70K|((cq>Fm19vvYBnO?E@S? 
z9A_Oo=LOeUx19V!jv=>@2PpxFL;c_-v=4L!BMTjaiP$^w1Du)M@7x>Q_1wkW0o*y9 zd$$#zE1^m0 zQpAZQFp7{2WG_;V+(uR-SsoAensE|+$~=R;$4+^8;zQU|a53iw=LYvAFN+894srN+ zE-QwW#?rC&vG=hzu|f6})?UoY%wgi#FYGz%J*yY13nuW)1{5&LkuCIk7=@C61LP%Q zxl?8jwhl0_HrD8ibT;i2?L>`8Gg=F4-l_97WAxpOuxY$$wTWYXVEUh@&fn9LV1pfy zv)J)J#{q}Psc?!3*v%(nse9B_Y9uA6mV=Ywv9!Omar7tjP~k4ZGJDQ!snuhJg*0N@^v)MiI@tjp$iu;rMkh30t!ft23#lLv&6m!RMO}LqT znBAGZhgE?!VINpm*rV73SRKrG=1g=mqYF|=djx-h6ySPbFq!LeJE9%0?IkvbHOYM4 zIM1+3AD~~ZXBie4))?*^b{nD%WqM!3D&q&!L9@tGZW(3WV3k{g?Oz?BYqh6qKy#U# z8(q_hJoh|`1fr|v6Nv)I_TqQPoXcsZBj#YCYoJN=XS?ATZSdWgd3v_6ZMyLT5YX1R)^|# zXxC|-+7{g$eTn|Pp_?(>*xR_@xZPA`o@8~{Ry%e%(_Aq`s5{2PN}C1Fg1^!>(z_yJ z2Fcilno$(%#A;;C^yo2#oypd-!|)z>Z~VAt!r~}yVozgB*j5(cIZFthif6KKv8u2Y z*fC6t9bplyH1>5i>G`IFEy0K4oAH0}BK8FK9abS0hh1e3VZzKZ^eNg0^<}K1Pp18V zJHg$dI4}+9OPz90bMx8aTWY`nvxqm4C#mE@VO z7|NRG`7ak)39M1rCuRe40`>{Jg!y90%&!a)GK0p2yFd@Ya$qHu>#lZvb!@jCwZxhG znR*!G4Pw1b+o4I-2sQs|dUz&8P(7$0sJp2>p-t7r>U$afXB=wkFa???ntzzjS)?|= z(Zz+i50dw&i$Dmt2khhF<1VMA()ma`VqjFG$;=7NT;?TgoyVppva8sBI0Mi2=&=lc zh2O&8zL_KDUgOT?jo~ALrGi%;&TFV(GrxkjmWT3uxhN+RpTZ7g-C>?VzcX5p z&-4W}Z&(Lj0=|&Wv1MehH$UD!?;3~L%c>Q^So)xF_dkIbo zjtXW7(gkA#k%Ah25B_pqCYR#)a2nYCSjU+&8BSUm^bs()k2qi3UR#n(dvym@rHZ3% zM_U3KH`WcPb<~8_jHo(PQC=QcF`;5$#l8w|6;b`5Ho88)5s)+5w{$So0^R?NZI;=N zl>|vDffG2v7<>o3K#h0fu0%U%Ibvwj>{AYGFK*e|XslaX z)2qt6GPbgRB~js1v9Ua}Tv?u9o>NX%L{>koy;Fa&@nFm8w$mN^)LV7chC$|U)~=34 zt}E`H)I9JnB&3a|ry{)={n3A!zSwT;EjEp{iM<9#J^%eZ9)-X0_yT`USMF@ydEO6R zC;nDICZW#>y1N=pAml`JSpEje8>wJe~L GTl0U-6THL# literal 0 HcmV?d00001 diff --git a/docs/source/_static/kaldi-align/curiosity.wav b/docs/source/_static/kaldi-align/curiosity.wav new file mode 100644 index 0000000000000000000000000000000000000000..32d106d7b598c623dfa9f7fd0413a7b362b679ff GIT binary patch literal 22576 zcmWh!Wq8wU+fLFpb(ghUC!M>qF`U7W2X}dJcXtNE-QD5A-CleQ*kFupjJYxQtzF%c zG)cdHNB-nTdfZ2H_d2ipI%WLuAwzG4klT#M0+~6Ys004AdyZX6qb^rhb2!I81 zH_WBFo{0h6fB>K`un}kn`hX+Ao?tGJ1$=b%edbpe&^((i&*f826Y;)-(DacoQRyA*T1E=qYx}OG;1Z8TpueM2vCd;B$$CM6&ai zYn1?E2tJ=|01#&kF~>QLvK@QMd??7k0mDPkTkrrLhUAl#tYxhIP#Tm99EUzqPr%ox z$yAWar#z<;pE67L0(}Uwzy^$l{U*ft6>t+zpIG&VU!_SApve7`sb<4gX7>PBBp@Lhg{!@fGbfVD#VC^zv9)NSZFTnnSfG;tA6E7&C)LxWqClh{& z`jRn@BUBsE6MJkuNG=9%J6de7DOTzlVx0XsbeOUT&>>TyXT(eUd+<4R7@@_-QS|tI zq@EVe)}kyzO{sJ22d+}al6veMb(X_neNS|7cTwrK7leid+s9&Z+C;!iISfp}*x-Gn zC)o*E9R2YAB#egO_v{p~AKk3mru)JFD+n+^r)5bWL7?RLC%q<)dF+@EdTSzN~4`YD66VGIg+7PR zvoJEyKE{$^zX>NYxEN{a=lBL@fEm_{_N(Moaw^Iti=AtX#a%z48vX#vQvE;1i_}PP zk==%wtY*gyY8XXg|A}Nd{&ntwo?sqK3 zHe%ISU-XuvKe8E%Kt6$_^MHP}Ig{KDe8zg(Yn;i{MfUfG8tf`?3>{AFwNtSa$1NZg zpKT8^ra@-72Y$iAMVA7xV3<>7JE1vlx*%A^m}olD8Guun$uQe4Hvh0QoJr&$YcR4F zyn|n}2AVS+*~CsX+tvq)b$l`0H8t8qP9@ypJk))+TVZc=p2I!tUdC(aSLCi~hgo9> zfxSqx-NRU7u0zh?JIJdxp!=A;6KbH~xVL(>`Yh$R+dg`X?oQWw$8o40+)Qe5)OHn{ zX�DOVzp%af5!h=^%cLCc`fn)>=Lz1JG#mAbSy}#YUUY+WqX!Hj#ZI9%^5SNgaP3 zlPyE-%N=2$0WC4Buv9RexPyn-@~w2o564>aCjJinpQ}!m*_L8^s2Cy9@3H(#Qcy43 zLHk>)n{A|}%_guptoLmV#zKP~RXJ~AWtzbDaD)K!9f0m`w~o|6z3olik@_EY1?>|Y zjvUqN%-0>S$Y%RyWV)jQ@3043LhT^wjt{f%v&P%k+XdJk`zDLPsK)z%fWv4UV0~rT zYMN)8k3DpD*ow^zi`ur-5{3M>Ew`iQd$z4)5W3R5*JeY1*)`ZlumlM)tu-G&I!OyK z8lP(ZWZq)G==_B@Ba_T=wiU?#Y{f`_%x;}!`(i0@cwx23Jx99p4*tY?!DzG3bRHoa zZL93lkzynt+fLlK9o8ZSkQfae#hY!5jAf?&NU?3YvMSv1Ckz zn2=!0cJo;q%e>GOY(tTwWEc9yw5xlSVH;w1Wa0lHD^25*|fsG+z@E4XOV}UIn zt+c0EDoy3c6dVFxpleLa^qGcmq>Z?PH8>RLA?qgdL(^HyXtW6r$9H0*tY!XBBl0&NET)$06>k;aP&X}fdI-*U?8!87>GY0 z=Mi6s>rM;6cg_Wqz!kWK*x;lBT}~xPI7Sdg-~(Cac!_-_9{i+phnvx9Q#WfML4k zq}tFK+i|NiSU1%4(-@{3-f^?#X}eFi&>*rrFnbxG?o(aG+CcLTYaBA%g&m(Q3{$0v 
zi*ANgv~koK)GF#*=o-{TVZz5a>qMVKXIND<3iTu{jdM<96|d#pVSZ!S>4glEqv8qq z6}&j!VD=^WH{~pCEb}y5$6~U>n2FTuXmBLi!VJB1=gm~4CpN^a)_Q4AXxZH(yGQ7nwY@a8t>c?})tA)l zt*>gDUH`DEqq4eOUOfL#P2tG$H)YE!JF2JEeXl* znK6-hoOv7eqP~LaX+s%7OcC=2^A2knr;A79-($zIQ+W*GAmJ6k7v3#SGAoy!MLh`~ z2KP8$JG_wLHk-B00IB!2oNle|pm&XF>u%yVoT$yN$^FZ&39QYn=~*+Ys-`TtysK

cwxqg~`o67`bNJ#Kl<_saL$ z>|^$=4C)cJHEvYg;ONWY%R}4skBw4KH3LRTFM?cOQh0SryccdWqthJ^>!_n@U>iueI=l!<%wP6)O zC9KkbrrG+V=qE!`Q+-uY&1Up0!(+o^H^ohkzmfPO@l||41OsGAD9;66FDvBP>+Q@2lgJ;JGv(j_cLUd?;@`ZpDC(YDynCrY#A?%W_KD03rL}% z@J*No@1qZ8jAC{&cC&798~B&_THYMdU^k=tRpncc^>U-6OTg!?=FaE-kF$vNhnnTM zX&TX0)_A1qX;DcY|L2!1&oAdchJ3F0EpKwu$LS_^NLzZe3N$8cwMY`NtIwCd?_=8o zE~z#Lj0|gxFhDXm^r|D@$Xr?i{UeC_aRYipa? zEw#U(wh4k1$zI;RrvrP1u8H^@{XUi+w<|^#y(VHsI3AK6RO=&Cv z4Zb07IC>)~NH{(bASq_rW4e{Tn=z4{BCL_ERcPgpWsvlhxK_-OotB3vZp#`3|6>F| zvCi-KK~!qZ(~s)PY6Pnqic9m)WGlZO|8(KqnAeIo|NFMOqK~fGQRJ+2{ILzRRiIz! z<%)Gt!~4VzIG4nTZU`FS->Itb(J8%T$GI%(b28TP&9N5Cu%>9TYv&gI`L#Une5r5q zeqEaF9-59OV|^SJ0AVVZ@pssi1fmr0dJ@<;|yn=aFcYe6c>dHrU;ITPRkCvXL=TS+r0ce#z-LE zV`eHGLiGk~@pAipQ;;r54YfaOvDQDYH2=|j=Vg?;9OE3L@d*S;ra{H@3A}d6CHEhm*`6WFm+spXfa1J7 zN#X7O)P1hIK=E1{EF8`bh35endd_@Cw^1{@^J~YjHt*(vbq~wY-_Vb_pOqi8-pzi? ze0%F%c2-R3>5kpD_0DVxn?8rt!EWK!NEWJohRuuJ5xpmZ9=a-&7P8ns)8n>~#&}Kq zG_`k4>6p~%smbYD()P0cW!2&GedRMMJ8LqUes_f$;}DWKOwC~jif1d%dz|o?=RV0} zvy${&;N9+%sG94|^|;{LgDngp?J|@D(404%THroJp~OHr;9Mw`>IHV*7vTC!ZM&2QfkaA=tQa})V?+3nspO${e`)JSHU%aVxv~@Xni1~vvguhPU5c#>i^Q!i@1(U&vK^Ocz z1M>XozJrv_l9BA2z-8-4^|6*SjoX{fH=k=l>diH?{$8jyR1c}0-W1aLMyIg!LzKia zs+=cso9}taCq=c;r@z-drBFG-v%zb#_XDLx7AFv~THs=ul@k z6#7PeLo0;Mj0jF0|F3wfn^~cASG$jOAEvnDwp=EVJ`)cSuHjNxd*O7-tpA@%h)YlI6cgs_v@MTZ-j*op!&$V znRRf(qGno0NjGZ5Z7&>ypxev{A>el2{f4rSXNz*O$8N=Q`CR!Kw~bPz=n>DrKEe9Q zJj>k2WHP_gmGEE6DX0}XM;pkz!4dM03JwYni<(3XNiW%Zd53$aN2~i|`9$d{;ZEKo zb|zyKTuDiW8i0k)8luXv0}ZruELDaMZM82Y&Vv!JoAF0gh{&CA-^jT762+Q%jjd#u9;yu{of;<}IZex+ji(tH|~ zg^DyeD!-=KpE!--aWvEvjn+cMZ->27t+ez9ss z*P4z|%|K0fnSa5f9LaaXm*bx|eVLeB`IoOB10=8;g?FXvBs)Ysgc~IO9<%)Jg&0B) zgm?$11wQpF^OCq-<_)JNp^4f#&8an4s_C^i>o?T%>lV~}tIn-@T&1iz(^%e_WZH{G zfqC>!o>6A>9OE}N@Kn&-z_tFMD%#86bDQT;&%N%LWGgR^`2xO9Jwv%jiKO15uB4u& zET(*;dNUmC*Sr?~EWrk$x7aQ*%6ck}d7M*zcTbQ_7kIF1=>1$>Qv#ocPry%Txs(Cm zSF#rGz!=C^%Tq&N-5Tv6ol^f$_dw&;exq(j`R_j)a?`UN-{=|j8AR5Uk|`bi@f(~| z61}us9L4L;?&7edOH{|hcE=2jDF~YxxLsxPm?YiJX@D-;(>1onKh^817S@J01-B$N z#{E^4^(i`2X7Q5YL(Xt z*=hD$;4He_egsvKLdtb`9ep&-H3>SuJD)=g`U-X?4-r;LXS5;r~Qnp_)LcZYjWZ}ak+YYU7gfZ2=*^z+aPVvyqp(8lTM z*&H@GEKa(zE!>4 zN;vir4;|UaH1v(rm%T}z=uZv*HwuY-6t+09uPV}0;=WBbNA!~Sgmax^;GWF$2=&noJ~J>CHf0MdjwwJB5AnPW|}t_2#!vxhHEL z=zkFX05ck79;d%!SphZ3`iI>}c$_dR+}r!E#9uU8+N2oc(JYS^Ho=kjD9cE_vRmEp zy>WQe;^Hwy{mSiil6FE}rwQqjwhw5Z)7@$Jf`5uzmA8Dp`&4+Fy(g>6{T>G{41OGB z^n2#@T`^DQA-yb~Avnu^1|NZLI;WD|qzuS)@`)vAuzePig^#DiGkrN<*lU?9=x+4c z%!Awx;W+6?*>-7+*j*?U_zBzv0{(t3!1>0?WDaNErO&6%ay~17Ipsd)sJ(Fs;d{Kz(tVO}cdg11 z@HSwq>bBgUn+Ke*$C~=--gG@_d0c<97ON|27}_wUc4{?QDf?U9aHwmyWfhrD-R?q0 zDSa)s)oqfmD&%sQe`sLPaX*vKU2n|$sE@z*01uuUDIP17aM|=UpurL2Vz-96^oYC8 z@#H^_YJ05hJyJkUpq*zGan^D&SPIrR_75H|d@6Y*lenFhwMZ{XW26d6wQw!}CFcP9 zEbAI0ntF@Owo{D1)f3w(t&HX=_0K9QewTfB`=ZJC`C~-cxwdFyymneMv*vbVm|-t< zrF>jadDy*xBkr@rLg_D`8BtFYp2e02>6A(0=WHf?3t|96@u${gE!Jjia%)^wx4lYL zmQZ%I+SHiXaY1cT4^pq!tTSvy@4x}l3*JjqryAv1fI|NGPVp>K z?3N#J10|<;xl9DUK$}B(<$Q!ySXpLo^J1$CJ&2bPTks~7k8X5K0XD(yY=SqEpUZp5 z`@o+r%9Ku&@0Qod9!jQ)jiS@SLf%Q%8JYwrbmZ7in_ugObslY6`&U?2P>}p<^Y^vi zPG;8SMgPrgRW=W;$}gT$71BvMOq|n-bfsCU%@kClGq6qy2 z5xnWl8Neo+URSTK?pV<>tp?;@TyDht&LNDC%IFBJp(;U$~u0p+N9m z-W&M=ANPPWfuw)BO6&F7v(c;7N3UXe&vzlzCH@r-gL#*Bo8k^7Ij1{N`&MhP<(7*L zjI;*Y^ARgD85xbri7Ai~-p>qY_viHD_U2PW1(Iah4Oyb=oNpGK<}cugnMbL4fPqjr zn(S)h^De)}yo#ws|NS14FU(WraB{*5O@HS%a4LG|6M4I7 zNr=EpF$*Jy_@0vUWP{~{-8U(6+=fXmy5^QC4vkf7YSB;EEN+i%Rkl8F-dKOBnp2ro zHLM-P0unJk)P0acA;RX$X&EYq?OE3dulc6s}Ulg$YYZ+rbWPgL0Z1>I8##Hkr`%?!>J^_DHnrW9AKiN&Zp27^* 
zIckegB0Mgw%j#-(V4oAH`mql`^%|+*P?H~)y`~R3%eKR1@5w?ZNJaO1i8%Kx~_JO?VbxGfd+5%l4&;1G;ka5gQK@|Jv#3hh?gc>7TM zep|g=g&lSzIn*v*vp3`epQcY{++n7%Z*v#(J%l$z^&*upnOn;gFguvrnccJ<;8DE5 z_QrHkr_wC#_}a9kI=5&+esunp{FAxj!m3JY{fv^5?~bgb@{e6z7Q1b{b%2XM-(`G_ zqnxMieAP&ghl1-YzF?*z#j}q`iu-59XDNes03J#0q!dGEoR#PZ^D3>bqrG`yV|Vj` zj^&+Gn@|4rtGHP)qo%$oyYof2O2^Y3(|H;8S)UO;G$Csnx17IN2)Me*b7`((m-kQK zJwBfmSH+Knut+PA@f*1DoHlrja}UOIoFSGH>#+pePfNLVy8X9pu>CStP(HryXW^<0>{I@4zt&>?4+COQ7|vSXIISSDNL)#y$W&SS@kDwJa1+dgx= z9(Y}Ke=q1opFlh8n!4&J*Ga&hsaw^4v)<=#qpObhH+`#Bm9Hv#RKzMvt)9@B*}AL! zRmX!aVYk4TfUTj<;Or4Zil$5QrD1M86-T`S{QUz*`JMFasYsN^E0()$6oUde_XWcN z(uoZGAel#)&@|f#>owaJml|-|J{owcBj6mKr;R~7rWv^=~pIe*aa_~MIIi<|Cu1Zt$Z9mYz_F)W6X%6Tg;l0(X$UWMLq zK5?o={y9MlgSP}Yyw|&XxkXCDL`}R#&KVX5{^8P7PTF6))YY}tX_oocG8@}2vE8?w zz@`J6ssF%J;p6mk%q(_1H<|xR*eyyF?&S%%pSkn6kC-ybByyZXjf^x&y05m)Y>2Mx zTR1TPZtkd`WnbQ?Vp?%u_{bKD+qYS-@WfKX&I_FoYf!pFWAYx|B;ZUZp z+53)PbifqVBTuJ`|6t0uGoyN@IuhEHQ{OhdLcL$fufATiKX-EWwycTYmS!}38vS`$-s!)u z8*|Eg71Wd#sy)F6yrGg!3a&ECV~w&!btd$3Ttv_1J-883ff+tNiW=SvdXfuAbkG>9 zd#kweNKsVbw4y1+--|6phyG0Z-M=8dm|o@8=-C-=7>0Zwt|ktCco}1NG1bR%>n<0_y(8w zp7Ffv@yO$cR}bGO-c#Hq;_*yh9MK+Wz0kT|eXrZI^M4IJD|;5}i!PT+s~$8=>^x-f zwY^3Bkp!g9(G#Mx28y)utsWm0XT+m8p{x+@a={+Kb)JGPWgLaCQbqx*NGnn5SMyqVgM?D)3AeYh&EkbFteeM=6+M(B zORq{UN>)e?iU#t%xh?EVtTfg)mJjPT{Ut?7W}>y$OU7W`uTZ56xzSyfn=wX3zbpBnbrJIH@#<$|Nh|q!MOpeyek#H zazMUa<|lb0((+-BAN&JYN*)3S(hf2A^B#+L%2MQSV3h-Pmv^uXVy6L%m>?jYtK~)YlQVj8s~Oob-Qav zcZV6Wu!T&eTef#ZaAHh-kH_($ag8zSB1?i=y}K1lWvFlo{T#l*w#@zVt7>>Q-51$)J+ZqJIl!<$^!QR?--gdaOwrsiN_lJ=TY(#rE@S zQ~8I~Mqcv&raUBA?VvV#7V>`Ys7z^`*Id~6sv}r;(%5O}Z8~5|M_PeML7*2Bl>GtE2OR(A1BBb_hVXGNh3FJE!U(y$L9cl?um zdU<&&jWST2Et)0-+3`@N!w;Qcdu*CvEg-hiZqoA^bY8S9-6PNaga=1u3A`Pe5dJ)3 zU1U{w-(b{7qg>+gS>8vykGG0-i5^caB^%KXHXrLb%NI+H{Uv$~?P~)J?V8HAJhpqgTMEBazi6FhScFfYWwGxIkGhq5{84aZC*7W^u7}D(+f-ZKW^>A%hs^WUj<)3X zbJ}0_Mbuhut|-sdjjH&yl4)*AH>!J#uP!7fWJ7?visAjt{kn9W*i9bbRjJl^;2 z%a@l~IYlYj$=q*&T|GYa>4+Wg%ao>y&xk5`gTUVQ5te30AimtNzwY;M=Z~_V+e@R` z5|Q7`DEVIh+NjM5$+3yP)x0-gF|o<{hP;ivwC%SIv92>-)M>O6^sspnF^F$frA5WY zzKFsjc7`|kkCfeGJ%cmp#~E~1GcQ|4_q^x*)rTIiEBJEA@St^py#o*X`71{W3*eJv zp>>k}Y}d}Nt6fg*7HbgM3S9vkNfUkuO}8X*Tdppa&;A?$&hoYsUx_{-z>d7r^-3rtVYULbMObT9> zAndh0a)WrT<6jNXmR_G(QBpLmytHMoWipXRD`uY)%B5P5u%HQH?x9b7w|l2~{ZUql zX22Ifz{U1_1b5Je3fFlZ2)Y@uJC>5z9oG@k%k!4_KjvuimvORYuo~31Xe)I0jh)u% zSh`~j@PXD&52OYFS$LKW(++R;EeC(evtE3;{Rw%IfA8oG?~J`5ZTyA>L&omv)iNSt z&Zj9CqF`)PtykW@9QOCXU(MfaFQ+__fAB2Ms(;n&*Nk?VDX#+LgQ*ioPjnylIQEc# zo>bvXtr?v@^|AQLqtEk8`!|NSj_Mq18c*uzHSBeaF#0}LzG%C0Pr$-}I`IkoLl@Mz zuZmT4FMC?%_umz*zZ`PTeR;QULl73ZJ5k*mi|OzCRd!0)EZELI4Xtz35DO`bs6Szt z_e(a@v(dEw(fRpiOh5I4F7rFLtZDRjPEmN?%

;~ZwDl&TT=L#8K=4A|`!4qxq>@DqRf z`i}OA{makNvgSlhiQSEURQz7nE{bNqp~>NWOe?oeFqJEH-qj_y{nHd%<6rzRw<2QTU08d2c>f|tac4xv~SQizTN4#O{4vjp>`lRdU6i#(6} z2z-~So_Ic&{}gsGXF8XfMH*#it9q0%6TJ)dWmig8Do=Q_6nBK<7#8%4=0)?%`ip;W zlvY_`yvFUd+9XJWFn!9O*yp=%oC9D~FGr-FJat;81!}&8N%}>8&4Me$u2* zc`JVZF6Vx$7iM-eVD7{bcCIQj@!25u!2L-?1jR4NO-M;Pg6Bi&G46226ZirA#HF+xBsQ7HssHIb&{5c=tV%BJSLNQ;ADPK^x_gFT zJ*3HZM;sb8Gr3>HTzP-i7v};iu6fe^9teI_#tAP@J?&q&-KPG*!eHi!l z)1%?H898OtE5im%`Zh$EGIC1VvWe3^Mg_8^#2gw5N`0h=k^qtQ~vYA&?l-VA+J_{3;6@r?Kb?6>F4!7zjB`w zZW{$roUe>jN+({24!3&fi#7gTx7(I946glHH?R4G_6UBTf57)t?C<_P{+ThRVpvUL zj{haeZR#-7%$B;kF)cNQE~k;bQg&P=3HuOL6WJ&1wSS#*pzM}_%E9R;D7%RT=rW`k zZN@-6mAnY`VJzWg36Ba3Iji6($KGyv{rzH9PDr}X%hbp5cL&@YcIoB|Uu{&<7I-%gNI9>`UfVCACWOhHokSx;b?_NG&R!@(EOk>txDYH zj%;I(bO(c`Me`DKdaC00hc&3ixQ`V-;@Mf(ARBhi*4KW+w#oL{egnPin(yPddHiDT zd*%#EjN`5GaqIU|#*bqkdOjU;^T63#C#<)e^{F8*CjFSYYuua#4>o;WpE6b@TcI0J zJ~RK(k7*g_G8*1~%y?EZP_qdwKz9=JoUd5!KADL#hm0I*9Q3A7aKfh0tHLMUmR$Aw z-!HmfHhlP!B`CbqIFRs{x%*8B`s(+>`?F$%^a@7~6<}{I5~E%}!Z=Yoqt(zrZSho( zHdySxz;>Y{=xNf2LHmae92(y*E)Or?&a&5zu)Y}2+0dc`S2kIL7-JG1Y6DNkSi+V4Z)pVH1{ z$Y%5?z5|@ff9pRvQ8%FPz>Wc#eIF$TN8FUWGp#DVo1hkByUvAh+56 zmcW1fJA6JUd9uUowNMQnWt*ZOtL?2X(f035YvXlJ)Lt?5$F5TSWh=sxdV3H4IGi^; zs&}eyJ|`8EtMAqwt>!ev8+X8;WTSi!M3g1Y>$AOARKnQE-~MUJ-4X`-A$SJ6VDi@V z?j+Rz8BU_{&{lXF)4-t%ma#-Yo`uoYU6lQ8?1z}=wrko$NxN+4I%@ibES%;xw(;_$dha3T-BqKD0t>41&iyEOyXM{h@jVXMa}Hi6^0B)0Zkx0CjTQx1VNP7 zhMTpV0!rrPw>KUwxOV&yxJ!D&-hMQ~Ia)dFpTUQw2Cv?~vUqrmQs$W2>D_v$rJ`X$ z!JhXWZ&W|LE92`o*IuccpuPhqM`VmxIh{8xV#1^`*M`Q$FBZ(}vgd98_~`w+%#j7d zi)%_>Hbx-Nd3o+>%D(OcJZ}Wt3(t?x1UGr^by{KV#*{kBn%1KQFHQ&qwi5;|L`ZB7Q(--Nl@u>+c0dlpUw?1jUIPaL| zGAC~Yb9ImZdYAS+Kd5NZrK#*e>-`!b&=}HT?O=7kZ<d7juPGG-`8y%Cli!f9=GqUi>=A+ct=zS5`8{}Kj`6^bg5@nMAa%Z7>nr>LvwuT86XgOASr$$uMSbe*`yR+VbiF<@U=+(1d;eZhX>U!n}-iG_F#3T%0XgEjcZ9bDQr*NRx#yr+{f@b+b&21$0mPaC#ZDn-?rPDz->A%T4b8 z^XTOvmWN9R3u>8zX+41t_!fJo(X1=cjn@k{+cJ-O?O`n~3~x~na`v8jAh z&W?9gPlvtOkpA)yuf2l``VEWT9=kVYTXd%1WWhtqXDrPcs=uUp*saxMbP5~U^(hT^ zYCCEtw=J?%0$R>|<<6jK5iJoDLhq|aNfYQIM^F1~TZ#oV1{#@`LB^Ygo8~$887PQ_ z0CkN0q6l}3lJ6z)`sJCTEb_SEaYQlHZM8I7=JHZWeWf3SF9n-L0n#F>vGQ~T$ZrLG;L$ICO!qhXYa9`Rv%3<<03fVSUe5_Y3 zk4;9yE7Mt1g?^hM)IJUUYiTkU8tSZkT#G#I9$aUt7FFG@Xev8Xl2i^?N4E8{+;O@4 z$JnBgNn{_;=;9n}m?NoaP$8`m9_p+#ZP6~(uyj3*<1GoO9_Ipn)Na;Nb`33tnn}@6 z4^s`arHm4IDdQ%yl~v15V!Q^qj)mAre6q6uxAyG4VsUJku!($fh}cq)7@#~0Dq^K^A0eKlICK{)ULVmEK*<-S{cTiPF-6`XGLdB z=icT;f6tUHEu2{HR~`0uZT0M`in`q`;T_w$dg=z7ew(tbgB|mTQ+OE?hc)1{h(CBc zGS5c0aE-ropl*UTLz`m^L;C}FX@i-sSv@%c+!(H!o521bEdo4GE_5yehFq?gtZ0^Bb9p7l zxn;Y(mt7WTbFGXcu)9m^oIw3fje!~T4AwKlkgOQGMS1ye|3j^iSF0wK{Q)>z&3_VVV0w&`t>))URcTfQ|fYu?^++2sj3 z(56>gH7ac{eX71uKS_Vs*lt>HIc;8Q5*mv&ztyYMHR@yPCAuT#O3Ns_*rg!u0X8|) zNQ|5U{O5e>d`R9W6Up7qKHvwJmK=a%sL{~~xVcgua+io&WD z6T>a%d}gm>q3k7sHmO4XQ{LZWu~MYS6E71u1!DvoxH~vWydV4#yrtYByvsZlFP9%J zvAFbbp@1eNgqMXq1gm&Fq9*dgml z^Ih{T>vH=ZTR+nW!*_EedWg6VzM!0kE`wGGg4U9?$W4pF@IlkRyHlU0&r~1jl6A{< zS9Ixy?WQZ{!KgdF^Z)%twro?PK1933Rrv$j+}l34&rzpo)^^|1mFiyTRIX>$ra#6H zhIz(YX0|2T{L+$Y&9w}(YaP>phptTB0qPfe1nV%fhLbDU$s<|r%%}8&pqeOj)endE zmbsMuh<%VVhqqajBJblFs{AU8l-`ma758=hR~nfA(fctca;NfF@UIJCp-nJFSRz=- zpCHKMZ{#L%SgeQi`EU_!HgzXBiFoeHYt4sZs4pQC^n-E)?qZJN^yI!_7t>&>KTSnj z58Wb{Ix3N^*kjTc7)b^YRx-z>7ba1Tk`o*~h@SXd#0@hzjv(QdewL}m%i5tXpV9Je zj^5~U;*b_^M+J&l9-HTzt@;VN|8-NkEt*B@f3@dKa`QdQ1N&%ug*^qkKwKc-IMbcC zi5X-D7)s?3Z1kFSuw|fan%&}Jk`V_8e54StkoB8Iux2uI7*u8py^6-5zN6lN4modw zg^Z^zWM4?5GlTgQftKHoFXZp#FW?n%R#D z1;s+6U=M!-pTUpe&SNZQB=ByqcQbHWFm%DW0w|<#Xk+L*X=U_(nWtGKJ&zthsRUA- zVZ=B5l;b60L#Lbm(~UQeHo2hG^3rZH_q9myE!GqE1@=?6Qqyy@zct%DNqt#+)yOrT 
z)ZH<^u^n@Wkvq0)f^>6mSUsPwvw#1jj=Es??53L z=~6essgsCO$309$v661cJL_xcKAD7;;6~>|=YHTZy_va-vyXm--h*C9RXA%vZ^{^O zfpd-HKd=OP11+MDpnZn=(h=$*)+oN6Wn`Y?Dj91zuUH#l9_1!&Aw8Y8-lf?NpcjxW z#Bb&rm`y*z_`^1_>KPxP^So~6cH}&^kow-4g}gTfyHXw_4R_3blu0xYDKw>1|1#!S zjk;^pPIA4$YIQSz#!eAWHJDuqud%!{Un9@bhuFVc8=NKh1roIcm~Y`lnj6PJf0Ak1$y1&5?o`#By4^Z~*xVSjO!|O99@~ zQb?hb&e2f!IL3-o92mKtGmJLFF^)m9LZMW4o!FnEpanAvXtnq*@Gj*O>gCKOPJsdV zQSc9P*!h~;ilw>~wcU=d@Kq$tat&PJ>`Bt_ao}q-o3hbL!{Wd{_8vqkY_~mv_F_S% zRPZvMV=lpBod?Xn3~W2ifLONLhnP#zel{oWa3z*ff#u{*!vt5CuQ2wL}5Vjr=+Xf&|X znd`g;M4(rJ>BI`a$EB{L{G+dS?ImXteLy$gFj*Fg zo!I`4yi3#6sxT_x*LgzI00{K6P%|la`Ng83TC>$!%wTl81D*uex{v6C4{e|B98JyF z9CnQ0pUp8uhZV!ZDbrw1kpr_-3LAXhFK18BPn5NB`!)k=IlQy3Kao(EtR3gN!e2 zYpLU1KC)0U1PR1TY>`d~@dXly+4g+1p2+T=hW5Ab)ZHiUk{N1GVn4K9yTCS`=B5Y9 zLDbytgRZ_4*YD6weQ2{kXFljWDM)I{)~APdn?{)a>GiU{jo{JHnhl8Mz63| zcO`OjZCi)z2AgJ$60?rDG1>Mh{W}L)Ggqmra9UZJTRE0}fCFo=3 zYKoBY6CX<5?5Lu9U<(NnokRO+IZ0U1bIQJgO&lH?iI>|v^d}V?DGSI&J`0h3hJHXK zJph4NJ;-oUChuuH)Y{iO*zw6U(v@9qVpM}_^J$DB7~0uHtmgI5?*ab6VWtw`AoX;| zPe)Ift(guqP_iv-mwtpIl??^-CzK9-oaHK@u&lu9ZNqH)nC|c@$iw`a$iY^da&0@I z%~+>n2sQ`&Pj}sM9GQ<zzf(!t2s40&A<`TwDZk=^J+<-lX%cy2^ zH*KX$H_W20f}6*e&1CH_%Gl zAMs(lLMVr3<_xQ8gP zJZD@s9=6e38SPvfgV^CpBCY}xC?a$y-h=T6kG4%0Xlx(RN$^DIE71jZ;{Q2MkO{;c z>qy=yV!r)5QwD16JL!GF1K0s)759SfBXpSAN&98(=8Xkn9H$u%2^RPfzU1`e=FzC& zHXdjhi_c_zb)K@!W^QJR?7l!Z@rkTL%v3+qKa6Vj934s9!hB{lkrrqwvVuN=xPc6% z+bBVn6j~%sv4%hs$xs47dN7jME{h>Q&lpNu4c6Jqq0{g~(|X!vKw@5nk0$E$?t)ai z!Wzb2iA*sx;_o@V+HSF^_L-IfKGkv_KSMDSQO-5UM~28y16o{raour@a=}?n$BjaK zAAPuG2g`|l!(mo9E#35g4V?F16Zap-Z{+TBxyzmg5|*MMprYceS}k?dy|r$wRqGyY z#Zk4`s9+w|JpU?aKdcB^n2iRVC zW4Mk!LCf3;z%jubS8Zq(Z>@hi@QE3O)PYz=JiEaCJNO;C7#kNT1O3Q3d@NOgO4(lb zQSL1N6lx!4q_#j-cv3hW-V5ij#@a7KA2?;9Yfu-ot(Bz;ynCMRp^k6J%yNvw`6o+See#I-o|F{!9BK|P@14bD6 zwi9I6h7LJz(KE8!`!rO|g1o7nH3GH57F+}uhT*oS5~q*tB>8w~MNki)W>z`2$;?5@I|0wdMmcYgfAg={ z`mi!lhwl~MPEo-L=o>aOc$Rr02rw_?42T&$ zzz?AoCgm4Cea=dv?u64gx|YXd_Pv= zE+y*tS8azw$M^;y6THECMZCiLL7fPLJc*lf|iH3}+tWb#hk2-I;~# zAp8|El{~=g1Iw5we-W98<{)!9SNamy0(FeS-$S^dCr-6R}%Ly0v9Be^V zEP%BGJw;8$_7d;$B;*%{S+E6?fzydA>x?jQ}QMF0OSv}82bjsRT5_n%R}udzWHgmDkh_sJ1zk!vliUbx3#El00?R^! 
z=~s7Q=NQLGS8uP2y3V-6KF`C%{Uc_H^|I-a6O>TbyNUj!Q14->70CsOmuMp3YE-wF zOHpBEm9md~q2vYcU*^-#kr zRpsi8y0I0Giv6F{KJPDjUb><+sbz)soaUhpwOzH2G`6%)3);XN?A3zVoOm{u*n!Ps zb!E3P*RanB3qNDm@?;YtFDf5K=PCNqN&i$uTK9WBGZNC`)4EYH&mxK>W#mm( zC$XK76HeAqMsDb=Z-=LcTkV`+SK2e&k<@6<3n$Cj&w11t?JTjkwA*b1TVq>C7~?gW zHM6R=Rph=)ecAQ7s*GqhX-<6}S@EuRQ*%e-{YI7XiuVGVjZY+D+9O$pt{^NTySSIA zpTw(p71JbJPdtLQgBwCNs>Vj<#1FZ=Ew}rhmL6 z*BR%Dq;4~^i6fjU_B!@2yor1R{{(-BpemwETzl;JXkFCn$b4y@ct1ZzSj^u-XR#jf zHnC?jnBh1oH`E^#0JY)Sz&o%Im`?TaC_IxqKl=)UH>i9{<$vv3YtL-^&oo7CRAtoH zeA!#^SMlnXzI#cz$`WHW_~qsY*88mT^XenDn`>?u%z@eECaUKo*!7}l-3i@y-6D;x^{{p|1mni5Zc@;59(IKNGYo+f6q;P~dQEV41Zr*JL*@ z(0tn{QKdKKw{3Nt@s-2Ba2bMq9D?(ic+R^dxGg=PWW_C{NxByiXC*4(97&3FnM^Am z7r9$j$PW_pm;=y(@L+}?`;(nRUdGJuIrt#BF+2dGFgGigvz2$7^MyRbJPaG5e_=i} zm)-@J0;#?^9S2PhG=jQ{iu)z4#d}`3AE>kM-}vdJ_f1aG?FUeS1TXOXyty!*G%0Z`bV*;#T7AVDjp}(?4lGi0pQ4keY z=oNDm8=@m(e~!1sm&X(+3nTH!HIkLwG2|hxpP$Y7fmn#YMEfv)25X=v@Hg;0NXyv5 z{2ptB3IIo_ID9MgbMSh256v!s!F=E8&b{sD%x4W^tyg`pcKv6(^hEJL&*m2vy>F=) z`myG9$-9pqK9(ro9V~+z-x|uzqwGC=FqQ5NdY8f<$R>_h7^ip|Hz;tTp7A=a_~XE*n|aCTmXX zqt;&5Af1}aZR2+QYEQQhw?A}3?qpYz+vl0>jqo0K&-c9cK;BNzBA>_C#dp{HfzEF% z@}>Ke`~w4@DRn^WpB0!xg@a9@^6(H~55!`OV@zbokvXUZi)7LHe|A4wkE`dAg7<<> z!5yJXbVYJpS|y2;q)I-DCJXQJNq#=>PhLECJI6;diCpF&EDkMVoMt?LufYZwfEPe3 zz)7GCtOQ?!^T9LVI}nE$U;}VATpCW9>Es-P9Q=s>`?)zjo!+^Kd#&YSi<9an8u z>wm5CR-fgqWudv!gqXS-UmE5Zta`vut?zGGZfr2-nRc3=(dz8;)++1E_HQ~!dvE7Z z*K!Z!EAeloMulkA9vlpRVzi;X@RzJnlE+@=> zL@p%@iG{2P)>Kv;>mBnveV?CW%!kK866gUq6MPOtgBqX-PyuP+5ojPRMoN(jXc<$< znolUmC+tV`?|#i|;Co0nP2TeGYhTfIHxJmt1}A$sPqg$@A|pdd7Yu@zZ>ZDj7FHPRd+mmE&!lSa}* z-X*7!{mE;@BGy+d3l8J)Obp9I`Y|TKtUIojb0v^sEE3gZm35SXpqHUsl(OJ=S(F2i4tQ2n)PZwK7Oi{d0 zEf^-4z`x9E=f2};*{?}Cp=GjBkns$hA6`Y>r(16l+)tg|?Ni$~x8|F68|=Ck?PP69 z%jK4(Ewft&wA^f-(&E>?(JM^TESYTsYy<3Nj;}jsx)E>C_lX)7b^#h_G{eqViR2=D z^bDGS6=O4)2Wa+o9rg?zhYm%AjF~V4c0m)NJkUV*+sq924Q&i{4QGVez;vJ;mGs)#C%KpSX;?0?D3pW-0gh5;GA%QsK0oKgd;sET_t-gyWeG)e4X4J z`7|aq))bN2((sW@0_ZN|cjz(53gm{zQFDDs-kt7i zof(c>woz71tHu1MX@?Ov{%y$CXX*~1lCiv9?~Haw`_i=@`KPu7W*~J%xh*jzHKinPd$5v7rXBX-Mr%6=1{6x07Wn>hp3LX*H+Kmn|#Zuu_;HhF(`9`8tOUt?}FY%)?tp|QK}MH8p-YQtLf zrDjCe-h$J_NV&1v6w&Ioo$fs1{hMk8OKJAKFJ{Df+&E!x$r0HWd7j)M<4cc7CQ1t= z|B3gAb_nIX61oNG93%mH1UJyFUJQ4JV|2UBa>vLqjxt;|yfsFc&Kj2+lT8OL$E+P4 zgF4@OiUTBYg`r|)^E)M7<+Bv|F(czk6Ec%d_FA1fvoDZZ)w`_Mpp@}FhbJuV@qPDO zaTUr_5o@~SNfrpF@nK#oCxg7s?2p#dZWqPSL64yvxESgWt^hAX+2GxfIAjg)qFA0+ zv=E>aJH zEY=UaWcC*#AUF~^Nl`$zHOy12Q{b|hBE1BmQ|UuFrL=Q&nC62%26|8p{sh-x+qAai z)?`z(4sRK+ovQt@`Ek>>mLIg)+PV6l&2#96obHa9o(y;uIflE3`$mw~WgDGJ79_n+ zUYv3!>Fb_oztvi9e4rO-r)Uxy?pDW_O|Sf;c2DiS%GqW9Pnp%e z+ODdm7M3Nit=c@(e9K-Q><9hLm_i)m3u*4>9)G`leXJw4Am(OlLM#@wPd-x4SFTo` zmi7=_q#f>S2nB2oGJ?H)nfCpy|Cu0Ts)4QhQJdWoqq(5cs0!3int8g7ddi5jmRggY z}J8hh#qloy01@MpO~IFE|HaFNP3VY?+GR(bw3|>Eh;wBAfF(gDpQG$ z@Yb=TiDh)l00%vYY=`*a)j>G8g}NU&7Wl_s;(g^7x!<_d&R!iG+E$o0>#sGxXsoC| zU$dZge0_4m$NC-B{mM_4kF6cvqA{;D7a5-Fzcs(GG5tlrV#JG`!%i~m$U?d^V2E-{ z)I8;%%JDSsQKh^ZvofY!aZ$QJw3a`QbCOm1eOK&)4P+|JJgfxnc=FmSGMhFvsm}E ze`tGSs?lXNoodWh&2Mxx{P`uNCbQaBCsEnHC@WQ!gH?0&3r!xqPPf5$ux%k-xH6em zVZz`zNQ}sdoq~zdB-shcSCV+wkOznSfyq}D{D$uqfI|)rZ(o+9;q5vF}Av}W>e$G+GADdr;XJGjYo|h z{W9&LmQkk9Z3i7M+;f9Jf>VI_&>?UiYpLhe$IK#I*0Ga0oJcL7vF{`Ko|9ax)@3iyQw?$Jv_)&;`+^{cIn-x zy^+2?zAX1r@6%vcHNs4#LiP)=jMvYYlx}1>wB~!I9u~&Fg@RY~kp5%c1n}V@|(Hs~b#jK>K9v!n4O@}XsUk2CFMET%wVQ`P%;yz8! z)t7xU{ImSO`*-=TQ@>GUI1lvD{jTTm&8*#I9#1LymELLFrAz7FmR*ryMM~6{s3Ij( zu_6MC$dRw`7+7)`<$C|@m`#Z_vWmQTeRps1p?VHe|>Cn$^j7>0)asP%keq? 
zm7@>{3_^w^Z9A~-#(&I51-hrL<8UWq=(2j9<`;#vE4?+T>1n%g*xQh3%mXtmpG>Jn zK=)zlEE38vQnw>#Ib&=(*lAQhekyz5_+T^CIovFjZ;h(w}wTfC-wWM`sY9gS12o019#oIq;b$na3=l-XsiS-=aE!d}xr1 zV`rFGgToW$hHed2U8fR_x2qyXw~QWC#wa%rs+!viB(>pI6mcPI1)d@CA(*Wy9^b6hq- zM_54`B-tr+#u~0e@R0kEiDLZ5e9uyH+Bln7an$L=X537a6V3r{px+TU5D>LK{IdbguXwT5CUQ$%M>x)ItpC^^)bfp-6aZEzn9MZ|Y zsb`GSeviGjQ#Gh`*iHLhL%8y=oH*1yVpFB54~=EWJ;yx2g}`dk_l%>gKiScozd89H zC%6QjFaH&93uhPYcfu6jMI(zrNr==Y$`q6d>xFsZSE7E=ZQ%vp36C&(FTn#Fi;rd? z`DaBG(H_A%uADWSv7cT|3CH~kW7$&lYI$*2Oe3R;^zpzOR*6?Jy>MOry{AQuB(MY( zi{9>fV84QH6`u{A8-f%aB)&!5aXqtCs?PV1bx1nz4EKyxD%^t~hf?*Sut&6Me3XbJ zej{c`^Q9C|H#eA-NS{N?Ch;(-@MG>=xQo;xi19J{;r(9uo%2r!hzw*+j|-&w{wqAm zIDlV?B4VBrf2S2P8(7mg*Et6`FPU`O8Dcu>gIj7TSD)`zmRp{i{{Hi?3r`47s{hI< z!hW9DJ(kR{NeahabrDUlzm{ z-Ys?4Lxz6?XSp;}``~emA=$?HpJvwjsW|oI3&;#>v8t{stNBB1YvZHFuDTc1q{dZ) z+s&IXry0Q>$AqC?yM1o?6o?Z%b~8>=_YwyYOh>h?+`i9c2i#Pn$3;=O_>reW>>=JE zxFCoS?G?pxX^fZHL_`*HE1`xK%u;jQ9wxSf-bdPv*$Dg1cEdm)Uf-Dfa?d?xPTJ+c zJFc?k;T}f~6k+-7V$f@X(_-%|eloXQCh>}74#Rh8EMJ&)Y4w)Y2W`@(y|wCw{;uP? zTfjnkl6bD4An0?@HQ#mOWt<&!7bTidhs<$=m=4=!x%y#u@e zobYllG-^$`W^WSIBJDpw2Wn#9>oIKIMU#50T){3lX{AAj0y@1^`xtUC-U zAs?3H*zG(E>4GeSy@xMBKcji_(?!WV7-bx3bzRqeXp$77F1^d#c!B@$S>0TfpLV6P z%;IF2L$=R6J=0@p@>*m!(6Y}@a;W}=esO@~mA1xV8zqa4+$Bhb%IS3>-GO!@} zyZQQA%Aikz$8-yCyH7~HOc2>ZZxJ|Sr{9JEYKx;sOU!$Z{5KsRDv&8kA z^{6?;bPQCP*fzUE3%!a~VIxrqu+4U=IFYguK@yP;i8@1BQn$CEB<}K0gJSgT}lq>?Qa#xZhNSS!Yo|;;?X>5AdT)Ygq`+ zF}~2RSNONEUd*_0@O`0I_b8rcj#Z?kHaTq<(yc{=U!dS zrjmxIe}5G8$LN=ipOD+e8cb_o^U3|5Uqau`Nt*X`?muB%egf*Q>o3?W+&QL2@LKv= z*hF59Jxd5@7kP)vE=BoA-w4_*w2)39B(6}$Ye*dmPN3ueMpR9?z-E21ddBG4j>LD` zym#5-S1&#)E(L}GB2xrXG4?eegSy0fNx<3QpTkFGUR)w{>G%(=H%rTl!mBg;!xbBP z*L`*Lrpd4B+93t3)M=!s2QlYjF%dTftMQFUIzE}^%MBK>1ig5p{q$rjIM*$wypXnr zCq`|ae%`Bt<4xfa4&d)md)U3K2-+#sbQ5}1(D!5eiiUsQU;K-3$M;6{gXymicLVs3 z(xX1tXfg02@)a@FZy<14SV8zHA0l;^GhvL}__{2udQHz<1z`y6s_VeFFZx0pf`Kvw zoijJEcy{gli{Ycv6WsZ{a-X9ip^^WFNW_n^`=+i>f`$OoHuoAr5?kq!Mb}{K5a$tj znEf=ZAlVxtMeyqI=j~^e>%aQcB$m(~l;sF=cIJgY%k1lhsW|1#A_Nb8Q`8q&5wa~Gg))_Gh~uYq;@yQ6sEsu2$?#C6cRRJtUx^Egp=p4Lu{ag3$uJT`$vj$$WW(z z{CwX?!~4>X$CNwY<{rzhEZW(S?1g073~@3HyRignmdOXn;LU63Nwu= z*h&JSM(0KM9vI&Jt?h-`Ffpv{S4SD59h2x0`Z@jtH;XzC_}-))jsG(GSy9Y=dEiCk z^Zc5Jy-k{QZT5)9a2Qj>%JY32N}6{%=Fh0*!STHP@OP@({+~z3$7L$h7$D!!lhrh% z>Tt#Pm5kb<&db`vE*yF~c`pYL`FgGQ{S;mm^Yh}CSjGI@h${g*g$kyH<{~7a;H6qyQQSG8C2V&#T)j3M;0{k5u07jnxcH4ycXt&04tw zfk>lYlhg#C30XBgDA*zGqLewttqWiVcrzpm3cCIQyGC4JFEy{Pc=s;Zg?9;?dM4S8^xrhnUUD_+= zL+*Bt2ZfBJ!*8ItXitn0hb2A0-vmBPRT)*P=%Mbu!olRxtI9QFy^4<`{1L_IHth-X zzs@HB7XOy^jCp}s&FEx?u#Yj<(!J>`sA%F^%n3vVVgQkVcH&->sB{WrCez9(@_5aD z$+%9NMpfaufJE3O2z@Hk$k!L>wrFpRFCSBlr7O$E^h&d4kG{-sdBOq8Er4~E<*E6S zrO32tvfAL*pV#ly9n{qtzA-vXU(G_xu!(6bHl~~3*;YF5LDP_@Fwy7>h|Tb)uzz4) zFaVg2iow+2T5wWa94;2S6n~Mto)%40(aiL6)?#)RvzF1r^kl{{a~PW#m#FcCLhO1> z6si#J?@G2Mncf8#Rd&TJuWReSn2|hwa9+VMft$s7%B;_(8w|JO%au@o+rs9y}ho5%&W* zpLUe7i)mzDV!ouyX`g5y{SNa8tAKfnK0slS50NgC@`%BN?YPyL7UaKhG9=XzZTGVy zrh@EF%de&Zqs-VjQDqS5lQoR-a-~5zIp!MqIvgO+7Ry4XtStlq5-!M+k%P&{^!!$WEO}SZ(0S; z1u&R4x0fGkt!4x{RFiS^#O&$WMY26hTsn1 z>Trp82ceU6jgmkup>6x`IZn`PC_{wHI4af$BSZNjy%4=XFd_+Y3sC|*gqOgia1ww5 zuER@V6ObO~0$VTmc%onDrTtE$);!iGXw{mv8nwDz?WK`x2)cQCAA@uPIkC*JT+h{K z>mv;fhNOwq$s(fzj5B*%4qBkrNZW&{WLGxiJ+vEo4LT3n0C@rFhg3p)U;DX|Nb&#x literal 0 HcmV?d00001 diff --git a/docs/source/_static/kaldi-align/i.wav b/docs/source/_static/kaldi-align/i.wav new file mode 100644 index 0000000000000000000000000000000000000000..9db292a47fd5f6463143bffda2c8426670d311bf GIT binary patch literal 688 zcmV;h0#E%?Nk&Gf0ssJ4K~_a(ZFC?I000010097iJ^%m!eE#ODm;6&mr@-6&B{IUDu{)X{0(y_Wn 
z#LmvF$5FZdt%jkOk%fr?fQyA2n+L+h-yHo94P^~N{TcQn`W66o21*Qb6LJvi0@vyb z&6ct|t_Z$m)ZXV7`qu?BB>ye6_pW+5cLZ}1lb9{5_cKI9-AR& z8TJY>0|*9Q6cZ+9GrcskEdb0)YS1^~vhvqF!5)GNi1!y4E4_o5k@D8wVC4{iHm^Wgg*7G^FTJiIuZEU_BM3UUL- z0_p`L5G)}XEW#?EAMzNh7MdDLBmFUdK>bH|L8UXeB<~)S9$q26CSWL&DlR6YAKV;{ zATlRFDYz|>F`zG|C>kO9AZ#OtBh?`x9*!0i3KaRF>__Zk`tAz`6c`Uo0POQ8@kIOR z58)hu98MH+4XFYd^dIdH=9l5#*ca2K+0O6S1=RP9(TTUD!T8=B`>X_-@krKy$Y8#b zsJ4&JgD;97pzW^vu%ojYwVJc7ul%w#%T?E?+QHw~;-=py*#+fl;G@$J)MMAg*+$;x W?IG@;?A`L+_s0CS{#pH$0q+HqWIqZ3 literal 0 HcmV?d00001 diff --git a/docs/source/_static/kaldi-align/me.wav b/docs/source/_static/kaldi-align/me.wav new file mode 100644 index 0000000000000000000000000000000000000000..e4f16f17c2a1bee24329276e8202e99aa010a976 GIT binary patch literal 2620 zcmV-C3d8kMNk&FA3IG6CK~_a(ZFC?I000010097iJ^%m!eEoYuNk$YyZgiT%SF=8*C*Wux3F`}w4c`ud5jqqu6$2K97Ze)N8!sKn9ljka9?KrrAN(PeBXK2CCzvV7 zE1fR=GDJ5WJD)t$J=H#IKo3IsL(oLXL<~f+LajlHKyW^yIyyG)F}W;9C*vWy9OV`@ z4nP(p*vQPz!92LauUDycqhp}Tpl_jzpW2%Jmv5K>o9mitno6C1qqVA% zu)Vf=zT(Gx&kWV7*@NFC=FaV5^jP`G{g46K1&0a>4q*^$5%>~V6%`kK7l9a{8uT39 z9g7{h9Nru=9!Vc<9|0gAA+{rOCfO;qIvO`pGCnL8CUhWE8|V~Q4U_~i``+?+>ObHh)!)gczeu+RuvDoTrG=*V zr`D#}poW~>o28yxpX;8op;)FvtZA_%x7xm1$Xw7#)&<*e->&5k?o0M#``G>#0%8V$ z3ltB85lR!Y6p0pt8O$469FQLVB6cNAB}^lMApERQabGDSE-J?21IL{vv- zO5aQtPU}uIQGQhYR@YU>P&P|%NTf#SLoq*+IcPQ4F{Lc1BNZ9W4rB$6`l{{W;M~DUMo-(=A#kRzv$HLDF+2r01;cw~p`S}Si z5?B)^3_1s51XBYR1Xlc1G0c-^_35o{Y0ABy+2R|44 zCLS-eFn=ztEa5QQITS}JQ&d-JQ}<0ZOg>EvPnS+APNq-SP$N&7NCZ8IG)^(IF}gBl zFk2~D9+VOl3MT~J{Fw3x=%3xO)1$_Dvd*6vltz=NlaZF!p**Yir!teAgvXOvy!qXB z;WEh%yR5_U+NSJ)?;G?*{pkoO3P%C$^mFy)|Bwt34G{>;07Lkg@N(^!^)&>I42TD( z_(ARg^4I?W3vv$03-SzA6UZRQDVZzkDKarkKoCyqQMgWzNzO=$PpwzPTZ36$REbij zQA1L;PKZjELyAJRL8?DXHoz;@A($Rt9j+K%4*&v|^Z@F8;8@oY%ip>RsF<7WmT8s} zl;ox>tu&||jq z{|ftI_Eq;b`%ne63dI4r^|kGe@n`?B2$cxu1CRku1gj63AT=zHE>kUaGp0b`NmNb- zP=ip9R9{<>Sz%SpQnFQ{TmxOPR>MpnMgc@_MlVI%Kt(o*EN>)7Ak7+)5~T_D0G;>T z>rdR+&B(zzwJ@tUpO=^xk4ud^9K1-|M&%` z5-lMPBa0*lC)P5cKoU*bP$p5(Ps3AkR|Qx6S*c%;VFF$>Rg_M_P&rn^Q_M`kL)tv9 zH6Sp^EA1xF9aR+^3q$@@@+s-Z-yPHc!pO1nqd}g6nNO8Jt+%gLopyuIh*G$eY0`UvfRFxB){E`1+21*lg7e*FT8Y?LXIy^)GN$^MmOY2a=So2x!TDV^SU!z^< zSDjR?S$tm^TMbbqNqR;_L(DwGHfS&;Cd3#?5AO&P0f+TE=8V=y$WpjptSzNmouHEo zrz)v7nRkHGh1IYq-00FawhN`WwYK4#_e1d}?XTm6i6bmHP1ejK5aV5 zJeEj$RUKFiR`gQGPO9_h(0;;+*9F><0lE z4Mq#z2*(I44xtKX0!;(K3H<}a0(ktZ^Xl!>?$h^){#Ex^=!o9H=~Mj>|Elyb_q7G! 
z3nvXr63rfRE$}vAH>))PIaNf6PXSaMQVLHZO^{LdR995ZRAg0wPyk89M^{I`Lj5}; zH#IYKDW)7X6r2!p3Ap-`>95{0*3QQ(xRR@pqve`&w9v8-mjQqjjAXnC+zH0csQRLD zzh>v!^WWxL;ke|`?ra4h6OIlx2uTQC5q%I82n7os5H|@o2PFab_pJ2A_bC8T09y4~ z=tScP@jU`S{0;A&^)?3H2z3cB5NjMtG_C6OCu6;BZ`2Ho{@=pj_6z+lINkqHnAm)@1R~>xSf;;Iior_~Hv$5`qku1k(zG5+n`02ICDg5BLWQ z0y_G8^NIC6`d|Kp`ONC_;Pd1<_|*SS@&D|+`5Xl71IY<~7MLTVC`2g~E>kp)Keb36 zNV7+vN03HBN+3_=PD@PiN{&eHLX|>wMesr9HPtTNEb=FH9i9^b3?KtY_&Dr8;*w&)=Az$s;(hJV19=e@5F!Tu z1Y-<2495lp4A&3|2>Afw_|o$c^(6Sd{Qme(>Vn>F-*NC^{e|$?=m7Cq{NVmx0|^hq z8vq||AMhe7EH^l&K+r(gK7>F8LEc0ONry?7Md(9BLg776JHbAeJ~A`DD(fb~BQqLr z67>wX0oL>y=Re*a*#OU`zf7m-@ap|Z z@rLM&>#p^|{H+4=41yPI7Yi6^9=0ktHyS$xJC-=)iu!yE>smjKaykVUbjINwI(%R-2 z$GNGzt4q)??>O_r>v88X>89`q2m25o5uXQv1|$vH3YP@a42Tl63f%yv_to_P`u6-s e0a^MO=^Wi#;$ilm0zLNN>VWGT_G1C$3hWWasoXOF literal 0 HcmV?d00001 diff --git a/docs/source/_static/kaldi-align/moment.wav b/docs/source/_static/kaldi-align/moment.wav new file mode 100644 index 0000000000000000000000000000000000000000..eb60e44fe71121dd2d49a376af41c77747acf350 GIT binary patch literal 9702 zcmW-G2Urx>_xAMdja^{ry@-H_fPi9GL}H77gGMn?>|HciqQnql!QOjCg9MRa!(J(Z zC`D8Rr1!qCeaijr?|beuJ6rC|xpU5Y&ikIbKSoWNQtN>rG2?%lynNjjXE}l(C>*=T zz&E=P1ViM=(#2aAKY?@jPU-eb?KHOTHX>^`%N&)jvQWNRxh z`1x4PjKzCvBWKtt&iEdI@H-)IW#za zbF{a&R3Eo?vhq?bRE(9b7P$+Ws7Nve=V4zsHJxPGtG%vS+N|#k7XRo2=Jy#J z_kaIdW&3U9w|}eut>4%BtY@CCk)Dg|1Wl3#=`HCxnL_R%ua*`_u1hw{vMmNIcdIwr z1v{0v{O3HvF~jzQYMwNXvS7~-x_7s?OlVX#OmAU&?wJ?xI+aaMA%2!&TgJMNoiO^} z(EY>XJX)>p3jYH#eM?Vy>x@RUskSAh{YBTp0n!|V)Z=5R)x2lCcVq#w$M~o(w9Qz( zs$}DvLl0eV``utKeY})@OYme&Sz*J_!D-k&n+HSdg0YYdVcFy9G0r|i6jjtRe1h?K zgInqT;@Z-yr4v5r%71Ri?Oie0s_y}`AXL$-_(6V5c1zi=USNC3{s(8HSCQZPVQc;0 z_;wG;bBk6l6g;Q@=sVq*RKb6Ly};@n_P($rwJNzC=nqr5)|H;0L*|Y@HZ^$a=c&^t z{T80+>*sV))k8@P%R3!f!W&~7vYHaw7ip4AJBU%T^>%LVwce#eqlZ5Bu6EgBGfm1O ze`|f(fBn9$%(>w8)8*My(w1H>I3+)ld$H(_)BF9EE!|5%ms3Hg=j7@s%=8tDZvR?5 zQS5qN@weO;^wh*Y`}QdRW!m$$Y~o?#N1rDDc460C`n3GH+F(6Nm1^tZGQ~a8N#n4~ zJ;nF+&=U7+j$<8)Y$i%v={KEg>S8~u^ZOsQ-79-M@pVhSU1?Us1jB9NNA(Qv?ocYy zZH8<{=k#IIr%kO4Z}m>HKdJ~KrGv6|&*tBn-I}&FJNEdPXHt1mm5ttW|A=WLOM?Fm z`eVdSzw6H9ENaOx%-D{ zVKH;(;c=~-9e2gfO&by^8X>-}8(tWBCI1YUT5_%NbnN-g$D^v&G%jc_GUcjoxSVjD zZ}-S0b4b6(0f#D^*KV4j@jgZ;U>onW+NsRq3iGwsrg>CNVPRg*)_dsFJNcX6$9`$( z$e~_2^bA#m^-LW$A6t<0lh3TJvJ8>s+h6OY`5zb4-~wH^6DD1L$t= zKYzV!@enWRsoV>KJZq-w7Zm&o)*RpV>#pw$i~VlPypYB1agWPS#qW(c+;Bvk5^$6h3fwW-`^TC&m!pHH&YLtpV$x8FjaWWRkbx}I zj?kC~#tdH8PSHK5Q2X?R4yL62AV z2{sLmn46bVlzoCjhs$^SO)?F4Q8T9XP30d2hjNZQ5IpIA{`j@5WK&Bbo2Sq@`UcdF zIWm<$%`NiP6y4Z}V1d_Cr;Rp;Wd7s>W-X^gp3=v4N&30?NQJ|=hm>%3?7cSoj*&OSc% z`qs}6FXaASa-z4^qCDW#gp;#s7w_G?E+Kj4?uj>tp0%t%{_k7KJ87vr&Z#PGGjD@LNDj)uJoN*MBw+f|omE)(rDm7GPrinaVCjo=~31M}U% zjqQ8CZTcwqu<%1uVQ|5ek5j+X-HqlD-eH^UArC_wMpuqkO=+2;nG_I#1Xm4x;+n3W zD#@WLkmvL!gTFZpki;2bx1zz;&tvH@mylWEv&TK0ba&F>vG+r72kLy5x~x;(qcWK( z`jXx?Ez|0ywZp$JsZOg5sZf+q@5twMS^8@+XVxB$IHJFrkbOUI{?{U-#_CY;tH|~_ zdB5bYYFTYRCn8|H`W(-!`QgL$6l7rl(z@uh}*1Osx#52V^(>u>)f$b_w7sXJ~CM?GMaImNEMHk-kd)?)lkyRT% zRTr*()s$0}_4>x-3)++0&dg6^9y}?GubyE@wjlk+j`x^xZGL*(##LJvC5$)uMB0hP zQQdyUdvnBh4`(j9?ssLu?drEFpSbT)eF0RO8gqH|#`LNmBPRo&R?Er4=jQ#@KRm;4zx=l{ zZSSR)OCalWA^mMtYn7Q%I(aSee-&_UcyYkap>n5cxsE5mx3KlPJdO2WnkgNKkU7Xx zqFRtZ8JMkxeY&Lsj_uRxa%-CFI-8KD{Atjz1X z{bS`8aWsju$Bk2sE6f3`3tGP?tELcYfC1ADkQPZ`xh4x-2^+c*PqO-j|e#Z3P{o5=Einl;ADe0D@RX zd${e7+COWE?|)P^e2w^?-{vv6kM1L0E0#K*^$Hrg!!K~yg<hYt&E7@p*B=l#&xV7p1R zTylnpU_|;qG~;{Fw)Vz@4RH-f-REk$GV}AQ(gTI$8@s1}-d&TfN^MU6H9PIu+`Lm? zUbosA6NK5e=e^DasVDqAc~fYdpS932)A@3t7p z@8h-`-u54E&uUQB-l!2)$5$S!jH-TH=hu)|7hLzP?qqFx<-M}P_tkk=$+`8! 
z`|Z4srb?&AN1fLglklZdt)A*p6d2@R;(E-^?3!WU#s6(!VfVSNuhplEp5!k3F!146 z0abYB)7koKErxCnZ4`4ARS*$8AKqHN4=DVu42EKPjNb&Mmv}aHWl58 z$(VjUPw$|Asfp||bk4 zD&Mz5#&}$Hj&}UjZoOJ*X{~y$+GFWtHP!N{)ns*$O{R6PRjT!F>x-5-7QZW&DIUtb zq&i`nV2)rsUqaTP@!%X6#s-;_^+s*G?uIUSAino|_rmV@jt@<YaaY02x6<76g$GNLKfnGS)6&KjNyz4%)vc*nCMRlU&hH= zPx?z~@^AN^+f>zKF5J8_ZUyG@1LOO|}eul%F>R%24fiLSMstNX4R za_C;;B*Q`TB^ouz2h5uL1J8P0yU;GzUSr=b%`WX_!x%G8z;#EWi*+zv4 zZi_LQP(Dd9OQ}_@Q-Y|?DS z_Ez@xw!Z2-t9Dh5#VUEPbh!k|M^T!PFZ2@}8t zfa+_l%nmZdR8t1fq>|EJp-*a~$+$f-5F^NblX)kCI z{Uo_7ekMeOyI{_%cnWeXwTR!r-!BN0{9xg0>1vg#c66wAn&*_@sBygSWOBG*3#@)r zjkXw&wTMl8IR%Kt<7sf5#lwk4qLs)cc$5#{SKuSa6$}%fl|GRTlOeJa**f`1 z87E#VDia3hJ4IW{G*Wx!zP@bT>{nxzmwMJd@AFa1Mwya*C^GY?+N_ zk-=rau{XOjt95_<%BuJ;5uayP6qS!F5BTC$^R#YP-HrOE%~#sDwKunJYdg`A)S+xo zZvWObtC#4z(Z4{KZQjoXps7SE|GM~z471p+BrVTdDb>kpQmwPzVdH2QXjfsIZ0F;^ zIec{pb2K?dIQ)WeoH}ug~ zj6G*d_z9*9Z-|b`LX}%pTGchHPV0Q@Yt}%WWy4bg%SOv&tF4xkRGU)^b>U4z#L8wdAlG2JBHAG+&$rE#;-*NB=Pm{BH^DQBa>Bs2!M zpoa4a;TiE}*-wfkV*JSS481T3K5^wGvn@vn*9jQ0-Bkv>2!O zQ9eO7Lb_2>A|}M!VE#`Q-A~GBVkUR5F?^df@0ooYCW+Ddx{via^|sV zx8c1`26u({eRbVCJLk4Lw%WCtn&axt)zhmc){LpQZ{zps2Wt$gx%t#ip;*))+hOgm z-Y&Z-I434#uKX(QA!k6l*k;{hT`0Yn{!won^wM|gkLyIbJYA6PU!9G~k>znGfR36V z`!2VW6)Iw_+%1<&n*=I;5bkKcKJa&+xOaGWMMp~e$Bu-a8Jd3vEC+AtqD@ieBpB(p zrj^VbHin%H2C;+GTK-3YQRE>lk?a=m$Z-5L+DJduN;P$wI|FtD1A_wNYVI(8i61Pk zle<~|qRvuBs{?En+j`sXQ;$}wY!BGIvVEuCqfSvjwwh&OCv6hm=MR#uxD&?1y~!%7 zf$u0jF59FytLT=mlh2c1lnxga^HQm$lqZ>pugzs(rl}LZERC}cv58T;TKih*6}h5#yp}y~K59-ftucj~BbXxdNrPTjsc+ZT z^)+-4>F4QUjXiWPwo+&zyQe5q&Q@_|Gr;!8PKGvS~ zrH>jrA$zqkJGsrsVj_|6C%h(T;b-y>@p1_zx|e&#qG&C)9f{>SID4+$Txyu4{b2wZ z)N0MzUflw7CfmZ*B1$5a%qEl=ijKjv$uvR&+3YbHN%@oMcmS4-G;t~HDRw3(L)M{r zSP~IOU7_M2>+y&^WDQkJW%Fo(8RD*2Dwm89qC9{mv2)B$###Ds{b6G}oxz17M?fMM z&(xYM_059?nlg=kaH?UIIg<529FXte2sfEQ4Sx-+>|4~G*-_Mbqsgu5YU9TCYE8R& z2N#ECP}Ra8CF3QZq|c-aWnX2FWV@s)@df^3svFAy%jju_x0;Ur3EDNrPE#$t1|g~O z!r_t#nNiNkUPy(q`|_)DOW6|1Mp3!Incqw#fiLD5V}|~;cEZ4RP18W9Hbb9cTyKu0 z%jsw~llvRcAQhb9tT>jw*k( z%<=4AkcqkyM<|qE&Rfn?^Zt`9cfjp$T5nKQSgw)Xmxb}&aagwBd%WI)(Q3 z!0Cbe{Y!g3cHCjK#UCAB9e5h@ZpFi@q@OBX2!DRDx z{ctVQf2#M-z7>NxhHK`h^k~M4`wC_u`;g114&T8$B2-HTBpW5Sgd(0legK?e45mQi z8p8tPc~hndHw74@OiAWdv@7Gz`JiXWYUDiDVa9Nfr^?h*Tn!Or5bXaw8>atIBNJ;SEavcl$aHC484pAeL> zSL&cI)ScD47(bel>4VG`E(7tzS}`S|ruM^FR0#S6sRB7va+Un?ynAFUk&FeQEue#2 z$vtEru!$TG2|?eY+pt8OCccpI)O6lzzPE6muvEywW4cf!h~(er&F2MC3kVVZ9t}qU zNag0TTi7$)e$JK4W&{kNufeBIfm$vG`j*Gk(s*oMN1A4AL{&aoU$!M9=D< z8_pVkGq~xWYS(E$Ko;`SEa_j{H?K$BozOYGGrDVh4>uS^Ujf6wcuz_g`X-~^J547hMu|6y zT=^REF>x64L^q*}5qEYzZA&+r*1^8cS3gNNO~*6rhI(zA*`3Ma`oMWK0rp~{6pue% z_zn(x(NDr}{4=~0lrtHMPeP|5Bq9WL>_KK3Z7{R+6_{<=EDE~7RMZBaPTV6#Q4!+b z;$W$VyiLl9y=32{m}I>8mgJD+XHfwE9Cd+8qTH!)Dw*6zd?cbt1(Ab~!FAX+ zJdsEteW(bkm3%-(QSsD9xPAMSh#}PEXmTx4i(kN3;1lq8JQgY!8{833#^dloEDLMI zc4Gn9B&-{Kjx>OkP`%ED=(LAkFJ_Zi6&uM`v9+u()OM~gsw+V?SINoYYbj^oqPcRm zhdsyg*l6Yq?MjpMX!A)^hAG)J!4zfMWMYk(MjYy$-wmS;=k*KpHTvU*Afw2*#}H>Y zY}g9%DT1*+Z1$l^b~0DR#&c<)9%)9Vpcz;ro`px^tB6iwK2eXmLwx1q$MI4;i5Npf z;c3`&Yy*}JtF8d^#ZXL(qL>pl8~q*GgD6llWC&Q7(V1vBQi9YV-N*xkM4zG#7>SL> zj$;{6)5wUG5SfwGL@JOsoIja=fLF@1=cn>-@=bg~po6{d8{T|=IzO1N;xFN6@hjkO zYW^dh6>kp}L&Z@hGK(xHTgl_(Op+#65by8+{5ht>nz1;{gw~?>pe~FGSH%4c_5uUAio8V1pwird{)V#fN-YwKjzOK!SadY{ z7P*L=ht*XA|LcxJmLZ8q0n&@$s2zG5xr{U-Hs}Xvp#XaAe@e9#nS&<6tCa9Q58&J^ z1VCQ~LHiBRl8O7m-RH#I7Ir$D!8Wm#5F7sRI}1?3>A10A5$yUhz!#`uPQddKxY{Nz zm+Ru9;EH{KH;@A<{BA7#*hO41 zrv@`&FRB6I$Qa1gHxNA%i^{M)*bZzq7KlxU$4B%ex(7b@Vps)_quU{FCZR#ld(R;w z*9`yQL^fA0sEmR~=As4uwtwp_9YUSxhU>{yW_psNo)mNQx>#z7+ZxUV4JZ-I4>D~>kCy{0#=EI 
z;%o53_)YvdR2MtpHU028E@KzrQ)FW2;rjfsezX+bjfS8$@Oh`ABwB_TArH4B64U{m zf-Xa^qMfJ&bH^O80qBdD=w^tUG>F?Ga0l#!F|Ffh&Jn1AGkllIz2o{gG5j%!T!Pm<}qZ`4#=nZ49*&uS4=Tvm!;6QfwkgBz}%Y%wt+-29j<&XSPhckv&{w3 zuz#0>XApNfww%pmA40^xVxO@c(8efeVF}EIOI$uT5)6m9?t&S>2ex1$yvJ~OI;C@joxna85RN z4bD5qC2~u-U>K8?@F_;XF%y1{;R3no+%cHv$GB?9FGs*WWG6ZmVwxsCiK0Yls5^ax zwZwL!b>J~#hh>5;sD1ZxUtsiTq!NdFbkGZXkS8FMO(&{_6L}l3K$7LvpgGtZ!3ohu zYBu&1?f|Yr8n`i75^jlE8=n=#P&e^+Bu>o%BUumFAEe{C_*}iER|gNBKJ53+g1^jHbgpe~IqK{YjMg0o5X{=yk*ZW*`bI z5%x5jutG$EOhNX*s%(iS!n~znEwRMsBOTx-{srHF1|kKJg<{ZL5R0D1!jU)Z6l63a zhT1a%ddZX9$IOKP6Iig9pceZAG4O&WnbGjO1#~n0hHYWrK`pVK?Si%L0h7$c!`!+L z@!$*|atqi}`UUK&gKhWTci(6Jnf}cAe$I1d&cw}+i7DX&z>3J_a}&4kapD00 z0Gi9undZtb0003zVB`8d>#v&U?B*nKxSTs|n0=FVmgUO&nYEl1$f~nA%i(cp?pB_h zF+?oIDhN-8inxKhW9`sFkOll_QW&xfyCuLoU4m?&gJv@x1JeN}_!u2T((sK$7&G01 z$>DR$`68)CdHIW4s7u-Rq3*J=|#ms82{;a%W-Wo0ohA!Aev zumX5QT{k70QVlKIj*HP@5F-k9H1TzfZp6Up=YP)Pk7*XhZ=q9uqx<+1ACgS5CT)LoSY%Ucxt=|pd4iZp++)VErR;w#+L${D0~|pgqI1AzqzE6u zcHp}pAvM(~)QHEl@@RQuUsr2w^{oPz$8e7E`L8dxy$S4&8oD{`*jn1+)L&wl%*k@S z5xRJ0bp#rGc+z$!cVW3i>5~=Y>?N>0!>Ak|>}~y&_x`7b!cWuMXM8>Nc}fjbo>N{| zTiN<#xRPw(Zk2p*2ikax&WMH`;(RZJj7>it&Ym3Wl*y|C+a~-p{^Nc|1{GzxsQ1(e zN7Gb0^v)P4I`6pOXNrI6q&XhDTr1tuJtKS=zOdhj*FHzKFq@SPo+A&E-_la(D8u> zU@f&Tm^?8pC}6;Gmf#HY3UC&F&N*V8z>Wuk)mwYNX?p!;R$cwazq=$uGZa>1r=`x_ zu>&e~AY3hQb$=OjHTaU>uv@Nmp0LO2tu5tPY~LgJ0d7<_bT4WCqp{-CFJHd))ei5I zhmE$5ZX6deGwl+mJf9<2v?zYTGQ+}(sMq0hLQTPB;J4oE>`t;gbmPO(z3aQK_EZem z$lq(m$GPKu+MT-H)KYwtAjrmOzrp?=i3hJ3dt}JaGL@Y}7u!2(lb^r6ZS@;?H22K5 zD|_#^mo4iF(!5aC_P7t{GfI7z$Glv%Flou==PL#x%L8K^r*hml8zj?&=S-!Yv2T{< z`97rXO@6YeSYN;T3p(_T{I+bK;z!_v?dhP*=#oYMTN1SZm~A`xxTCFb1>b_NV9ueR z4O@KFme0-q@?YSKc_k%pb6e7T_sMWYyrz|o5j+au#s0QFb@SRJ)jDI`c;p|UN&YV0 z^IWEiYfw~qzI|JLQEgdc(ig$dIvp1wn8|E+{wcAw>(MFCW`su%O;-o3a%;2w%sU86 z$!7AgezNMf-ns9Vl(amx%D(rn_;SR>D;NL1oRHIA*7mU*X7Q4(N4P;k2j8Xh{*9lx zVQjT(>4#a%ydrtCCgk1UH#fABU&d?xc)9Za^?#ma9eon_O8VJKwS?vG5bTxjQ|sdr z=rD8Vg1L*67g4h*XQ?NZob_R_D)_K`SxIGcJ7%c8M(UBXCGD$MC&&*KS~&G z3%rD0UXwII0pao}+iymqfN%@nSGLD^>4YcRgft_z=yj-qjyF{qTj=w|C*B!RihZYR zp2rc7b?(Wo#cr=X6h6b=J6s^~B~}1<&%l}ZM_W8D)RZaZV=ITZ_3!E+ztntn>AKa` z-ILxivz4l4R1Fk#J^A-}U*)szG{$YoRY5RWsmZ3>Y;Jfrhpi63J!4hKpB{GH^Tw_H z!d^#NxG__+Z*Z`=v#z$f zWVRbWrL<&p5Ph@y_IBQHTihz@yd(9|Rsf4Ewp*Wc$#56BQuZNs)i$KVGv{jiG|3G9 zEM_@UOJAeA!BONqM`pQDcuA1R_Y>p^*NXYpjn)QRPupnmut3WRWVJDRm~+@coMLt> z>jPsOu?T$t1%d|3g?eHX8#n2Ws|%H0irg`cx@&y;M7ipae4Z>+T0AI{+I6q~65q7_ z&G?JByQ+sawSg0kZVzoEi7e%0IK|F#lHTV*P*Yg&%q!lgc%EMRIlHXjtzXyip^BEd zwR6g4m6fmk>i=p#IQo@7VmZ_8hA+!?y=0nrrgN7k*ZaQvCfjm+gC?-+(Z@YCrplVP z|Fn4YT^`8m0z1KhbQJ}j6g+S~;D0BiHcT5>=%JP@;%;GWB$D7z<6%X3SIPVJ)e|*} z#?kh@Qbg~-j1m=D&vOiPkM&EPLWU7QKKvZ#@bHC6#Z#zpOyo@~rDp z&#Ny1&5Pe`DXFNPK476gi+MX`Oxqfn7GP&T%6W$GMkCp`#2m{a)6S0k`o@Ym`99Aw z^M9#b`t|AH^U)mnu*`XcRPV>8+ogN?`X>gs1Y~>M6ZbH-(U0|~4GTbDNIelWaJS8& zJ+R}a&Y~e9so^5_FWgs7y5yDWl(2D#V9EHA~+(Bse;|VCoQ|m4_H@dCP3m{Np2j51VT5M6{u6 zl_Qnz_ags0@ZsO?lil;%A2udbcontRdCUHT+t7~|Z|#=5ZMQ9E9|FovMZgMfgM+|M ziMaLVH!QBdP#^W7rX`}gV(9T$r`B~mK;5TYJE3O;xPA@U5wUg33C|vfpRChFu~vh; z2#Yt&CgL5vX$c9o@?5rF+ctYNR)tKB-0synHFX?$S{p< z{EU@b9AI&I%obG~n0@*{PxXgNp_;rJ^CL!O1e4k(iO{;mqM>(x#rs zt_5Ga+b4bg?{nyAw9R6_W~*BxbNhnTP9TT-)FILr%RR_Nr$ln)qWah4RZ|a z1NeQ@589;@NA%z6apOUF7e{XPg$Rq)@Iso4p5w2y2VK@Vzq9*bwS=9@%pfwcQbdoa zu}}CYB-EEG{vMIaFN~3zSB6Qn1K>|z0*@h9_^+%mk;*>Zsm#tzvVl)o__Hx?6}y<> zfZhhe$?x@VwW;bYsx*zeF3C`6)azr6{)R7f0{WhjL)^kc(Cf$~aM9~e2?Pgu;Z{%_FSGMcOBU-&ynvRX7_$0c(Abg#&h;0|mq)1OF!*3cZ<2il1RF{WTMpj-$<)`GXF!_*biTz#m4Lr(xN zsprP;^~&-0<4(p?)B%!V;#02x4t;@gqur>?<;JS%6yfdiD7eD|=K?6u0bPD+9voC|e8rBDcT7n%dVg(FZu+>L0#Ik*H2^|Fl 
zg3aJI@OwZ@djKjR8We)zz%F{2{{J;+X(zMyM^l)&l?F_vJE=a})qJNPg`Pt^_yjxw zx4>cW4^Ro9p#P&o%siXuddiZDp*-oiz|UYSh(OmN3i=UFhYMgIWA)tQ$bIt5JbU3 zFb`^ja^cTNH4=@)!@t5y;D^8{In~TJ7(5B>f-x9_T7eh9J|Gl)4pxH!paU3h&i`p} ziqIxu+#iOlLfTUaX=4n#~g3O I=pX3+0c&gomH+?% literal 0 HcmV?d00001 diff --git a/docs/source/_static/kaldi-align/this.wav b/docs/source/_static/kaldi-align/this.wav new file mode 100644 index 0000000000000000000000000000000000000000..1c8bed9560645262561c2caa126515407e1e243c GIT binary patch literal 5194 zcmWNQWn9#U1Bb7F*WFDcNTUb>iU9_SfFLNFf}8Lk_M9`OSlBbLXQNm+8)Hrc5d#5f z0Xg7~E8x2O{`)*{zAv74-_MgZH#Kzz6$Dz6k~@3VnoU9?2m}KE$FAi6#tsk&0wRKb zD%@1a|IZPihhT{?w^s-f0qcNWusUFqC&YW!dB=0!bJ=y*u?58Nj(Wx%?al;yr}4FQ zi=E+E>kR|>K-$0)-5BSP`!{er>&0CbL3k4O_wbi^C^?v+V1jq>~f;*Sq``xxPoolT( z|7w1(JE2Wd&CzBmev!=+%ZJ-VqHjC9Pj@YAJNu=nky)2koAZkP=2SDfUDa9sRoi;G zBc}i3V3Isc^^f_P69%y&e?^}G!jZ#REaM#KG53ui-$(32w#OmVYat4Jbc{I+S!a0HUfzu~YBJV{TV$IRm@F`J_$vY=a ziTg8No)8)r6;&6hnb;M0+SkFq%SZcHu>QbrMcsk-y6=0k-SZtD^FBkiu6CTNWDS$P zLHe%s!aBgsoDZq>C!armLc9;Y(Q_Jm@bj71uhSYv-lLy{KG^#j^SbzRUC+m{K3f#< z7`>FPW?qL=0X#J-AU|w0cEO}=2`y1?gkcN}Z31Zyejjlk-Ad_(F)g85v#d)*82TVS zF*-OVm!FplWZ=PVee=YN)t4=6+`qt25_{>d=|3@^@Y?)YvGnPpIny&|W+rFEWQAoX zW*%MmWq$Oc>FGCT|1}LUsXj2tZ?BNf%%P^&>j16mruD2@YDzI}Hn&<-R)%Sz zsnW1p|44g7d31D(=9v74K4<&75C8mAT)pqgvy-%4H_z{?oA|D)Zq2FME8;tQYYx4F zv=ofsEuY{N&p3MqRcpn%i`b{4wF!`TMpEgFS(65RwsA5zU9?EjIT||v!93)=re%)a z5PfKy(!RPkL_A+)?YTNo+4KG5!}ssLP9Gx3rjMU8kAQUE`<@EKH3E*GB4~_wJnX9eF{EpD(CP0Y&ZzKH|5T{gE2O5HgcUt!Q6pfoqy`jb)iB z!De=SaqDb5tZ4>_YN29W>JbHWHhoEZA5`mk7JDt@BzTwn!lzG;ujOBpuP=Xltf8uP zN(X18)FXhc<$w0K&Pa~DO!yx)HflpkQ{0}!h6TIkKTNzGW(;}73A|nOaxuBZ#M>*+l<9} zKr1rlm;iH*d8_%G3D8iLOT>zv%gw4c&?gahXI%Pz^xl!(pEEiRwLkxx@;KqWw?3=6 zzPH>{gb(8#_GwIjCCeGvG~NVQ^1g)8gp|cc7ao~j7_l*Ao$w5H69@y9V}iUnPO-Dg zF(?gf&Tch#z&aE?p@a8ENgA}e%{1TGp$F?Gy6zDkb4~>|MB3tKCz<1ap84AhTw48n zZ*FSAhP;Hls;pCst|n$g3nzsC*IF%S1!phqFhLI`SV^X}#$_gfZja`w`jSSf-K3g4 z-Yj3HV2&)2-&GEd1of0RM7}tEbN{8)Co2w>|Fx|(KwrB$+kpmH;&Gf!N+E)_vziLJ1VMjqXaC1 z^!4`c?mpU{_>ui~?S0}!&WR-lycL4_#O8>O*tf2hwO{9b5;bLZe=t^oM1ljsWpVmY zD3i~g%-i7qmGQ}URs26m6B3>Ub#RBhP3A@8UD{C1z0qM~wx!yZrg-1`bDN;$W=nH- zf*7YRv!g&Mz&+$cKn6Pl94C7Mz9m8Dq-G}M)va*m7N?giVCJxXEGQ}{?)izad{XY6 zw7kht0aMv4>62(*85=2y#3`tm?jMceqivF%-*w*^lFd>-IZP>(D@67_YkN!E{ic8q zvUkKc_Q$WT)t|p}D*u7$eNyWWEh+D#`#kOPu2r4)MjhT==qjEhXhqCgUIS~KQ|PZ2 zOk&T9m@y?_%AAS%01a(D%xtMKAdF|!hcriw-_0}~NaFKtTj#8=ZEf?q<_~kXE%cC(K+n>0qrq!XMbGOuy|)SBFjG`Bcn0BFgt#UWXW&YR~M9~U5Nz(KC+V- z>xm%jEu0&B1hX7cX*;IhJk~x^E`Btkk;cpVM<1%*${7+#|HJN-w!vn23-_bp-)A+l zd#szwuK#>z_N(TGjZHr`MRZ*4#&#X=fe%Ki3oKXQi^(8fmRtg^^o%2aZT>a*=OY!;41`hi$S zJ;S`jKkU;Ur2JD3T zA?|~G99x}f-bH$(^t*1c@r+`@&@8pT#&1|Lpy{;?OFB<~Jo#zk+m6Q}FAu(0UlZ{t z_}PiqzyB+%X?R*zt9_UMfzbN0Z{5g!n7{Zjywns5sVL~j`Hpa{ryt{?gi%g z7xK<={}gunrgC}wb%H&DJuCyAPffzUC-f4^aXr{-^aiY%u$1&a!VCOHN(1E_%YP5TA4rgK!%p|O8u2b zmASIDa--tDvbvwAm@D1czkbj@Uafj3?=uV=q4tGFuD;S1Wn8aQn3f`50l&kv@XOS4 z)@N)uu9&`@kc-MCX0a1#pGcqR|MIqz0l)w&z+XUW@V&$z&`IQW!eK%^420z%S)Moe zOO$*<1vrxLOIv~wl9Dj1@ee_RsBGGC=3C5EIC`1LQ!q@nbrnMzPU*I z5fx)?98Yp<^=RV-MYJe#kPQ2*00AGDT^?gc3le zPqr7scj)h#jc{Mf^?sSP1lKN->8`kU=|U8LSrnS5%I)?tQJlobX}4aK^(k636>6*= zJi5h5)<=1(4YzeG$D7r1JKtHY3R44+Kg>slL$w!eXS^Zy;{&y*Llox7)1H4Lt1(B! 
zKT0n9G;(dyOm!+fk#kjXL30}8F9=eP+8%HZ=voQ!V-fnN0@iSaW*@8(ZZVv5@g3E+ z=Z<7Zf%Gh7HR?IA3?Rb`v8A3(@T?ICS`oM#;V_=&cKC`N?=3g*#U3wuGpdNa(HH_0 z^D4cEoLQrhj1=z-U57l3i6Ea*#43MPt2L35j*}$J#1m=Miq?xK+Zkp)I zV&87oc2Yv>MzY6(zibC9Si3$f7I#sN_G4W*@HF9?p#@ff#fL35myIM6b_5NJw~dtr z90c=6_5d3xw`Jh5Z$kWtNr@!-sVeLx_G!{9G7kHR;#2E+%xY)ja1l_VxvOl1Z85^M zEy4%=Q#6sW_r@v)S^}bwOFU`dHlsI|N~zWH?9{+0+_?_lTcJ9Ru1*U20qE#~3(7nd!Sq+6nrM_yo(g zYF+85pTQsOXfQ04*E3}_KPbYoZy4Mo%ICvM@&p`XWHCxG-{*LU-qz ziE5rF9DLh6qFMk4oSDW#GQM}RyAbq*BpcF!zi>xv1zI<-hx|JnAr2DwLeCq=@qY+J z5)+V2`CAJ}4uYiQZv*^C`Wm42Q`dkPe|dt5 zRqG}`_i~y`LQYPI?@Lml!lhI^wnth^Xk+*Qsehu?ic*}rB!!yESM4C90^2YHm);;i% zMopYpvl00$K3SA+SxLKw>=^nMJ;k}qIMviNCA;sIw~M2-lzML9A7VucuD+D=mqmo^ zRFh$6ksQ?Mu#TB5Dh@vX*`at4Rt4vYpOEU!>CSw@i(aMWHi4wE;R^_V_Kx}ndVY7m zk>F>ZS6h`3Up(iKW(DjZAUD{kEbi|9`~I7PpMJZ=p2D8%`l|2ozKrN+EpI;;w8Hl$ z;k+0N55Z}iUYftrsD)KR))1vS`Q>r> zqbwxMAh7;U@@S5FYlEaCZG0vAjU%4c*ncr%y;P0x_s)-k485VY@eS59>PdugTEDlb zCy(muo9cf2vD?>Y;fb0Y(fXDBBISqWG=#U*HUPRkB&F;LNv$uII}$&k9Zly|+SHfU zO2e1l#_8Qb>X(m@Kh9Y_a#r%wr2Vw2;Xi%Wqw55*-b;}6`d?zhM^{Nt6T|#MY_F7? zFiWQIA57K7Qlx>0-kxP$3I4A`Wl`Wb>N4`Cj7)FA<(n6w+fnbSDmW81gdwWqgV35+ zU>^?(PbS;Ujv5YfT z5B|sxg4Zi5k3CIM+s8R%_o_##uJ`__zx~~yp!GsoYh7QYUx8y;3 z?9XTwUNw0;_>rW6To5$BQ7n8-+N3wokGVY>6>$%lHu&9H6qix|F5)*w>eyip&TChl z##lU9(rJB^!|F|7%0=3a0mJ!Qwvk=Da=Z=ar-S^d_ z4`4Zgwc}r;{@xqHd5Wh>BdMK`Ct=v{Gv|w9xK~s+oCxtu-z*JXbq9Wny4#S7-tF1o z8SvDBDOi6)CFD>#tbP zapOG4+ZM7~3F{=&F(Pn^#tci~-_(r{9wwN%y|OEYBA^c9wY*h9$x7%ta}{VI^MLNT z{Tp(hyU#e8Gy-b$R2db(cAErt9^momnHb9x8G#3E%m&o`MKHZi|{)ntaZc& zvL6JOd;7uLu?o{+YdE9`7wwvg+;7VCR1+Vgmg?{5qtQix-7OoR;6f6!VPdV&{-cKh z+&4yA=TLUIF1n{eN4~;4IAJAFIKDz)u)w$64Ki75mb*sv~0(}f?aC`Mb=(EUr zXp^bI-rxCXOdfBMce-mzDWqD literal 0 HcmV?d00001 diff --git a/docs/source/conf.py b/docs/source/conf.py index 5a534e126..ded6977ac 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -98,4 +98,6 @@ rst_epilog = """ .. _Next-gen Kaldi: https://github.com/k2-fsa .. _Kaldi: https://github.com/kaldi-asr/kaldi .. _lilcom: https://github.com/danpovey/lilcom +.. _CTC: https://www.cs.toronto.edu/~graves/icml_2006.pdf +.. _kaldi-decoder: https://github.com/k2-fsa/kaldi-decoder """ diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index 2f4bdb3f6..f3d2b0727 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -34,6 +34,8 @@ which will give you something like below: .. code-block:: bash + "torch2.3.1-cuda12.1" + "torch2.3.1-cuda11.8" "torch2.2.2-cuda12.1" "torch2.2.2-cuda11.8" "torch2.2.1-cuda12.1" diff --git a/docs/source/fst-based-forced-alignment/diff.rst b/docs/source/fst-based-forced-alignment/diff.rst new file mode 100644 index 000000000..56b6c430e --- /dev/null +++ b/docs/source/fst-based-forced-alignment/diff.rst @@ -0,0 +1,41 @@ +Two approaches +============== + +Two approaches for FST-based forced alignment will be described: + + - `Kaldi`_-based + - `k2`_-based + +Note that the `Kaldi`_-based approach does not depend on `Kaldi`_ at all. +That is, you don't need to install `Kaldi`_ in order to use it. Instead, +we use `kaldi-decoder`_, which has ported the C++ decoding code from `Kaldi`_ +without depending on it. + +Differences between the two approaches +-------------------------------------- + +The following table compares the differences between the two approaches. + +.. 
list-table:: + + * - Features + - `Kaldi`_-based + - `k2`_-based + * - Support CUDA + - No + - Yes + * - Support CPU + - Yes + - Yes + * - Support batch processing + - No + - Yes on CUDA; No on CPU + * - Support streaming models + - Yes + - No + * - Support C++ APIs + - Yes + - Yes + * - Support Python APIs + - Yes + - Yes diff --git a/docs/source/fst-based-forced-alignment/index.rst b/docs/source/fst-based-forced-alignment/index.rst new file mode 100644 index 000000000..92a05faaa --- /dev/null +++ b/docs/source/fst-based-forced-alignment/index.rst @@ -0,0 +1,18 @@ +FST-based forced alignment +========================== + +This section describes how to perform **FST-based** ``forced alignment`` with models +trained by `CTC`_ loss. + +We use `CTC FORCED ALIGNMENT API TUTORIAL `_ +from `torchaudio`_ as a reference in this section. + +Different from `torchaudio`_, we use an ``FST``-based approach. + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + diff + kaldi-based + k2-based diff --git a/docs/source/fst-based-forced-alignment/k2-based.rst b/docs/source/fst-based-forced-alignment/k2-based.rst new file mode 100644 index 000000000..373e49f3e --- /dev/null +++ b/docs/source/fst-based-forced-alignment/k2-based.rst @@ -0,0 +1,4 @@ +k2-based forced alignment +========================= + +TODO(fangjun) diff --git a/docs/source/fst-based-forced-alignment/kaldi-based.rst b/docs/source/fst-based-forced-alignment/kaldi-based.rst new file mode 100644 index 000000000..69b6a665b --- /dev/null +++ b/docs/source/fst-based-forced-alignment/kaldi-based.rst @@ -0,0 +1,712 @@ +Kaldi-based forced alignment +============================ + +This section describes in detail how to use `kaldi-decoder`_ +for **FST-based** ``forced alignment`` with models trained by `CTC`_ loss. + +.. hint:: + + We have a colab notebook walking you through this section step by step. + + |kaldi-based forced alignment colab notebook| + + .. |kaldi-based forced alignment colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg + :target: https://github.com/k2-fsa/colab/blob/master/icefall/ctc_forced_alignment_fst_based_kaldi.ipynb + +Prepare the environment +----------------------- + +Before you continue, make sure you have setup `icefall`_ by following :ref:`install icefall`. + +.. hint:: + + You don't need to install `Kaldi`_. We will ``NOT`` use `Kaldi`_ below. + +Get the test data +----------------- + +We use the test wave +from `CTC FORCED ALIGNMENT API TUTORIAL `_ + +.. code-block:: python3 + + import torchaudio + + # Download test wave + speech_file = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") + print(speech_file) + waveform, sr = torchaudio.load(speech_file) + transcript = "i had that curiosity beside me at this moment".split() + print(waveform.shape, sr) + + assert waveform.ndim == 2 + assert waveform.shape[0] == 1 + assert sr == 16000 + +The test wave is downloaded to:: + + $HOME/.cache/torch/hub/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav + +.. raw:: html + + + + + + + + + + + + +
Wave filename | Content | Text
Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav | (audio player) | i had that curiosity beside me at this moment
+ +We use the test model +from `CTC FORCED ALIGNMENT API TUTORIAL `_ + +.. code-block:: python3 + + import torch + + bundle = torchaudio.pipelines.MMS_FA + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = bundle.get_model(with_star=False).to(device) + +The model is downloaded to:: + + $HOME/.cache/torch/hub/checkpoints/model.pt + +Compute log_probs +----------------- + +.. code-block:: bash + + with torch.inference_mode(): + emission, _ = model(waveform.to(device)) + print(emission.shape) + +It should print:: + + torch.Size([1, 169, 28]) + +Create token2id and id2token +---------------------------- + +.. code-block:: python3 + + token2id = bundle.get_dict(star=None) + id2token = {i:t for t, i in token2id.items()} + token2id[""] = 0 + del token2id["-"] + +Create word2id and id2word +-------------------------- + +.. code-block:: python3 + + words = list(set(transcript)) + word2id = dict() + word2id['eps'] = 0 + for i, w in enumerate(words): + word2id[w] = i + 1 + + id2word = {i:w for w, i in word2id.items()} + +Note that we only use words from the transcript of the test wave. + +Generate lexicon-related files +------------------------------ + +We use the code below to generate the following 4 files: + + - ``lexicon.txt`` + - ``tokens.txt`` + - ``words.txt`` + - ``lexicon_disambig.txt`` + +.. caution:: + + ``words.txt`` contains only words from the transcript of the test wave. + +.. code-block:: python3 + + from prepare_lang import add_disambig_symbols + + lexicon = [(w, list(w)) for w in word2id if w != "eps"] + lexicon_disambig, max_disambig_id = add_disambig_symbols(lexicon) + + with open('lexicon.txt', 'w', encoding='utf-8') as f: + for w, tokens in lexicon: + f.write(f"{w} {' '.join(tokens)}\n") + + with open('lexicon_disambig.txt', 'w', encoding='utf-8') as f: + for w, tokens in lexicon_disambig: + f.write(f"{w} {' '.join(tokens)}\n") + + with open('tokens.txt', 'w', encoding='utf-8') as f: + for t, i in token2id.items(): + if t == '-': + t = "" + f.write(f"{t} {i}\n") + + for k in range(max_disambig_id + 2): + f.write(f"#{k} {len(token2id) + k}\n") + + with open('words.txt', 'w', encoding='utf-8') as f: + for w, i in word2id.items(): + f.write(f"{w} {i}\n") + f.write(f'#0 {len(word2id)}\n') + + +To give you an idea about what the generated files look like:: + + head -n 50 lexicon.txt lexicon_disambig.txt tokens.txt words.txt + +prints:: + + ==> lexicon.txt <== + moment m o m e n t + beside b e s i d e + i i + this t h i s + curiosity c u r i o s i t y + had h a d + that t h a t + at a t + me m e + + ==> lexicon_disambig.txt <== + moment m o m e n t + beside b e s i d e + i i + this t h i s + curiosity c u r i o s i t y + had h a d + that t h a t + at a t + me m e + + ==> tokens.txt <== + a 1 + i 2 + e 3 + n 4 + o 5 + u 6 + t 7 + s 8 + r 9 + m 10 + k 11 + l 12 + d 13 + g 14 + h 15 + y 16 + b 17 + p 18 + w 19 + c 20 + v 21 + j 22 + z 23 + f 24 + ' 25 + q 26 + x 27 + 0 + #0 28 + #1 29 + + ==> words.txt <== + eps 0 + moment 1 + beside 2 + i 3 + this 4 + curiosity 5 + had 6 + that 7 + at 8 + me 9 + #0 10 + +.. note:: + + This test model uses characters as modeling unit. If you use other types of + modeling unit, the same code can be used without any change. + +Convert transcript to an FST graph +---------------------------------- + +.. code-block:: bash + + egs/librispeech/ASR/local/prepare_lang_fst.py --lang-dir ./ + +The above command should generate two files ``H.fst`` and ``HL.fst``. 
We will +use ``HL.fst`` below:: + + -rw-r--r-- 1 root root 13K Jun 12 08:28 H.fst + -rw-r--r-- 1 root root 3.7K Jun 12 08:28 HL.fst + +Force aligner +------------- + +Now, everything is ready. We can use the following code to get forced alignments. + +.. code-block:: python3 + + from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions + import kaldifst + + def force_align(): + HL = kaldifst.StdVectorFst.read("./HL.fst") + decodable = DecodableCtc(emission[0].contiguous().cpu().numpy()) + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + if not decoder.reached_final(): + print(f"failed to decode xxx") + return None + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for xxx") + return None + + # We need to use i-1 here since we have incremented tokens during + # HL construction + alignment = [i-1 for i in isymbols_out] + return alignment + + alignment = force_align() + + for i, a in enumerate(alignment): + print(i, id2token[a]) + +The output should be identical to +``_. + +For ease of reference, we list the output below:: + + 0 - + 1 - + 2 - + 3 - + 4 - + 5 - + 6 - + 7 - + 8 - + 9 - + 10 - + 11 - + 12 - + 13 - + 14 - + 15 - + 16 - + 17 - + 18 - + 19 - + 20 - + 21 - + 22 - + 23 - + 24 - + 25 - + 26 - + 27 - + 28 - + 29 - + 30 - + 31 - + 32 i + 33 - + 34 - + 35 h + 36 h + 37 a + 38 - + 39 - + 40 - + 41 d + 42 - + 43 - + 44 t + 45 h + 46 - + 47 a + 48 - + 49 - + 50 t + 51 - + 52 - + 53 - + 54 c + 55 - + 56 - + 57 - + 58 u + 59 u + 60 - + 61 - + 62 - + 63 r + 64 - + 65 i + 66 - + 67 - + 68 - + 69 - + 70 - + 71 - + 72 o + 73 - + 74 - + 75 - + 76 - + 77 - + 78 - + 79 s + 80 - + 81 - + 82 - + 83 i + 84 - + 85 t + 86 - + 87 - + 88 y + 89 - + 90 - + 91 - + 92 - + 93 b + 94 - + 95 e + 96 - + 97 - + 98 - + 99 - + 100 - + 101 s + 102 - + 103 - + 104 - + 105 - + 106 - + 107 - + 108 - + 109 - + 110 i + 111 - + 112 - + 113 d + 114 e + 115 - + 116 m + 117 - + 118 - + 119 e + 120 - + 121 - + 122 - + 123 - + 124 a + 125 - + 126 - + 127 t + 128 - + 129 t + 130 h + 131 - + 132 i + 133 - + 134 - + 135 - + 136 s + 137 - + 138 - + 139 - + 140 - + 141 m + 142 - + 143 - + 144 o + 145 - + 146 - + 147 - + 148 m + 149 - + 150 - + 151 e + 152 - + 153 n + 154 - + 155 t + 156 - + 157 - + 158 - + 159 - + 160 - + 161 - + 162 - + 163 - + 164 - + 165 - + 166 - + 167 - + 168 - + +To merge tokens, we use:: + + from icefall.ctc import merge_tokens + token_spans = merge_tokens(alignment) + for span in token_spans: + print(id2token[span.token], span.start, span.end) + +The output is given below:: + + i 32 33 + h 35 37 + a 37 38 + d 41 42 + t 44 45 + h 45 46 + a 47 48 + t 50 51 + c 54 55 + u 58 60 + r 63 64 + i 65 66 + o 72 73 + s 79 80 + i 83 84 + t 85 86 + y 88 89 + b 93 94 + e 95 96 + s 101 102 + i 110 111 + d 113 114 + e 114 115 + m 116 117 + e 119 120 + a 124 125 + t 127 128 + t 129 130 + h 130 131 + i 132 133 + s 136 137 + m 141 142 + o 144 145 + m 148 149 + e 151 152 + n 153 154 + t 155 156 + +All of the code below is copied and modified +from ``_. + +Segment each word using the computed alignments +----------------------------------------------- + +.. 
code-block:: python3 + + def unflatten(list_, lengths): + assert len(list_) == sum(lengths) + i = 0 + ret = [] + for l in lengths: + ret.append(list_[i : i + l]) + i += l + return ret + + + word_spans = unflatten(token_spans, [len(word) for word in transcript]) + print(word_spans) + +The output is:: + + [[TokenSpan(token=2, start=32, end=33)], + [TokenSpan(token=15, start=35, end=37), TokenSpan(token=1, start=37, end=38), TokenSpan(token=13, start=41, end=42)], + [TokenSpan(token=7, start=44, end=45), TokenSpan(token=15, start=45, end=46), TokenSpan(token=1, start=47, end=48), TokenSpan(token=7, start=50, end=51)], + [TokenSpan(token=20, start=54, end=55), TokenSpan(token=6, start=58, end=60), TokenSpan(token=9, start=63, end=64), TokenSpan(token=2, start=65, end=66), TokenSpan(token=5, start=72, end=73), TokenSpan(token=8, start=79, end=80), TokenSpan(token=2, start=83, end=84), TokenSpan(token=7, start=85, end=86), TokenSpan(token=16, start=88, end=89)], + [TokenSpan(token=17, start=93, end=94), TokenSpan(token=3, start=95, end=96), TokenSpan(token=8, start=101, end=102), TokenSpan(token=2, start=110, end=111), TokenSpan(token=13, start=113, end=114), TokenSpan(token=3, start=114, end=115)], + [TokenSpan(token=10, start=116, end=117), TokenSpan(token=3, start=119, end=120)], + [TokenSpan(token=1, start=124, end=125), TokenSpan(token=7, start=127, end=128)], + [TokenSpan(token=7, start=129, end=130), TokenSpan(token=15, start=130, end=131), TokenSpan(token=2, start=132, end=133), TokenSpan(token=8, start=136, end=137)], + [TokenSpan(token=10, start=141, end=142), TokenSpan(token=5, start=144, end=145), TokenSpan(token=10, start=148, end=149), TokenSpan(token=3, start=151, end=152), TokenSpan(token=4, start=153, end=154), TokenSpan(token=7, start=155, end=156)] + ] + + +.. code-block:: python3 + + def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate): + ratio = waveform.size(1) / num_frames + x0 = int(ratio * spans[0].start) + x1 = int(ratio * spans[-1].end) + print(f"{transcript} {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec") + segment = waveform[:, x0:x1] + return IPython.display.Audio(segment.numpy(), rate=sample_rate) + num_frames = emission.size(1) + +.. code-block:: python3 + + preview_word(waveform, word_spans[0], num_frames, transcript[0]) + preview_word(waveform, word_spans[1], num_frames, transcript[1]) + preview_word(waveform, word_spans[2], num_frames, transcript[2]) + preview_word(waveform, word_spans[3], num_frames, transcript[3]) + preview_word(waveform, word_spans[4], num_frames, transcript[4]) + preview_word(waveform, word_spans[5], num_frames, transcript[5]) + preview_word(waveform, word_spans[6], num_frames, transcript[6]) + preview_word(waveform, word_spans[7], num_frames, transcript[7]) + preview_word(waveform, word_spans[8], num_frames, transcript[8]) + +The segmented wave of each word along with its time stamp is given below: + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Word | Time | Wave
i | 0.644 - 0.664 sec | (audio)
had | 0.704 - 0.845 sec | (audio)
that | 0.885 - 1.026 sec | (audio)
curiosity | 1.086 - 1.790 sec | (audio)
beside | 1.871 - 2.314 sec | (audio)
me | 2.334 - 2.414 sec | (audio)
at | 2.495 - 2.575 sec | (audio)
this | 2.595 - 2.756 sec | (audio)
moment | 2.837 - 3.138 sec | (audio)
+ +We repost the whole wave below for ease of reference: + +.. raw:: html + + + + + + + + + + + + +
Wave filename | Content | Text
Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav | (audio player) | i had that curiosity beside me at this moment
+ +Summary +------- + +Congratulations! You have succeeded in using the FST-based approach to +compute alignment of a test wave. diff --git a/docs/source/index.rst b/docs/source/index.rst index fb539d3f2..d46a4038f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -25,7 +25,7 @@ speech recognition recipes using `k2 `_. docker/index faqs model-export/index - + fst-based-forced-alignment/index .. toctree:: :maxdepth: 3 @@ -40,5 +40,5 @@ speech recognition recipes using `k2 `_. .. toctree:: :maxdepth: 2 - + decoding-with-langugage-models/index diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst index 93392aee7..4cdc25ee6 100644 --- a/docs/source/model-export/export-ncnn-conv-emformer.rst +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -15,8 +15,8 @@ We will show you step by step how to export it to `ncnn`_ and run it with `sherp .. caution:: - Please use a more recent version of PyTorch. For instance, ``torch 1.8`` - may ``not`` work. + ``torch > 2.0`` may not work. If you get errors while building pnnx, please switch + to ``torch < 2.0``. 1. Download the pre-trained model --------------------------------- diff --git a/docs/source/model-export/export-ncnn-lstm.rst b/docs/source/model-export/export-ncnn-lstm.rst index 310c3d8e4..ccf522dec 100644 --- a/docs/source/model-export/export-ncnn-lstm.rst +++ b/docs/source/model-export/export-ncnn-lstm.rst @@ -15,8 +15,8 @@ We will show you step by step how to export it to `ncnn`_ and run it with `sherp .. caution:: - Please use a more recent version of PyTorch. For instance, ``torch 1.8`` - may ``not`` work. + ``torch > 2.0`` may not work. If you get errors while building pnnx, please switch + to ``torch < 2.0``. 1. Download the pre-trained model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst index a5845b0e4..51fc6c8e5 100644 --- a/docs/source/model-export/export-ncnn-zipformer.rst +++ b/docs/source/model-export/export-ncnn-zipformer.rst @@ -15,8 +15,8 @@ We will show you step by step how to export it to `ncnn`_ and run it with `sherp .. caution:: - Please use a more recent version of PyTorch. For instance, ``torch 1.8`` - may ``not`` work. + ``torch > 2.0`` may not work. If you get errors while building pnnx, please switch + to ``torch < 2.0``. 1. Download the pre-trained model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 13f55d073513b3beaefdf0b7e16237b35199ca04 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 12 Jun 2024 17:45:13 +0800 Subject: [PATCH 172/216] Add merge_tokens for ctc forced alignment (#1649) --- icefall/ctc/__init__.py | 1 + icefall/ctc/test_utils.py | 87 +++++++++++++++++++++++++++++++++++++++ icefall/ctc/utils.py | 52 +++++++++++++++++++++++ 3 files changed, 140 insertions(+) create mode 100755 icefall/ctc/test_utils.py create mode 100644 icefall/ctc/utils.py diff --git a/icefall/ctc/__init__.py b/icefall/ctc/__init__.py index b546b31af..eb1ec47d1 100644 --- a/icefall/ctc/__init__.py +++ b/icefall/ctc/__init__.py @@ -4,3 +4,4 @@ from .prepare_lang import ( make_lexicon_fst_with_silence, ) from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo +from .utils import merge_tokens diff --git a/icefall/ctc/test_utils.py b/icefall/ctc/test_utils.py new file mode 100755 index 000000000..6fa883dfb --- /dev/null +++ b/icefall/ctc/test_utils.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) + +from typing import List + +from utils import TokenSpan, merge_tokens + + +def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]: + """Compute start and end frames of each token from the given alignment. + + Args: + alignment: + A list of token IDs. + blank_id: + ID of the blank. + Returns: + Return a list of TokenSpan. + """ + ans = [] + last_token = None + last_i = None + + # import pdb + + # pdb.set_trace() + for i, token in enumerate(alignment): + if token == blank: + if last_token is None or last_token == token: + continue + + # end of the last token + span = TokenSpan(token=last_token, start=last_i, end=i) + ans.append(span) + last_token = None + last_i = None + continue + + # The current token is not a blank + if last_token is None or last_token == blank: + last_token = token + last_i = i + continue + + if last_token == token: + continue + + # end of the last token and start of the current token + span = TokenSpan(token=last_token, start=last_i, end=i) + last_token = token + last_i = i + ans.append(span) + + if last_token is not None: + assert last_i is not None, (last_i, last_token) + span = TokenSpan(token=last_token, start=last_i, end=len(alignment)) + # Note for the last token, its end is larger than len(alignment)-1 + ans.append(span) + + return ans + + +def test_merge_tokens(): + data_list = [ + # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 + [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0], + [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3], + [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0], + [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3], + [0, 1, 2, 3, 0], + [1, 2, 3, 0], + [0, 1, 2, 3], + [1, 2, 3], + ] + + for data in data_list: + span1 = merge_tokens(data) + span2 = inefficient_merge_tokens(data) + assert span1 == span2, (data, span1, span2) + + +def main(): + test_merge_tokens() + + +if __name__ == "__main__": + main() diff --git a/icefall/ctc/utils.py b/icefall/ctc/utils.py new file mode 100644 index 000000000..ad49b5ffd --- /dev/null +++ b/icefall/ctc/utils.py @@ -0,0 +1,52 @@ +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +from dataclasses import dataclass +from typing import List + +import torch + + +@dataclass +class TokenSpan: + # ID of the token + token: int + + # Start frame of this token in the output log_prob + start: int + + # End frame of this token in the output log_prob + end: int + + +# See also +# https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96 +# We use torchaudio as a reference while implementing this function +def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]: + """Compute start and end frames of each token from the given alignment. + + Args: + alignment: + A list of token IDs. + blank_id: + ID of the blank. + Returns: + Return a list of TokenSpan. 
+ """ + alignment_tensor = torch.tensor(alignment, dtype=torch.int32) + + diff = torch.diff( + alignment_tensor, + prepend=torch.tensor([-1]), + append=torch.tensor([-1]), + ) + + non_zero_indexes = torch.nonzero(diff != 0).squeeze().tolist() + + ans = [] + for start, end in zip(non_zero_indexes[:-1], non_zero_indexes[1:]): + token = alignment[start] + if token == blank: + continue + span = TokenSpan(token=token, start=start, end=end) + ans.append(span) + return ans From d5be739639063913a926c28895526e9ee92d0683 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 13 Jun 2024 00:20:04 +0800 Subject: [PATCH 173/216] add distill whisper results (#1648) --- egs/multi_zh-hans/ASR/RESULTS.md | 9 +++++---- egs/speechio/ASR/RESULTS.md | 11 +++++++---- egs/speechio/ASR/whisper/decode.py | 12 +++++++++++- .../whisper/whisper_decoder_forward_monkey_patch.py | 1 + 4 files changed, 24 insertions(+), 9 deletions(-) create mode 120000 egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index a7f3bc4f7..e411e80a3 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -6,10 +6,11 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search. -| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | -|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------| -| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting | -| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 | +|Model| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | +|-|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------| +| | Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting | +|whisper-large-v2-ft |Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 | +|whisper-large-v2-ft-distill |Greedy Search | 24.91 | 26.73 | 0.91 | 0.94 | 2.71 | 2.98 | 17.65 | 2.81 | 2.47 | 5.16 | 2.10 | 6.27 | 8.34 | Command for training is: ```bash diff --git a/egs/speechio/ASR/RESULTS.md b/egs/speechio/ASR/RESULTS.md index f1273d41e..3c556f74e 100644 --- a/egs/speechio/ASR/RESULTS.md +++ b/egs/speechio/ASR/RESULTS.md @@ -17,12 +17,15 @@ | 7 | aispeech_api_zh | 3.62% | 2023.12 | | 8 | **whisper-large-ft-v1** | **4.32%** | 2024.04 | | 9 | **whisper-large-ft-v0.5** | **4.60%** | 2024.04 | -| 10 | **zipformer (70Mb)** | **6.17%** | 2023.10 | -| 11 | **whisper-large-ft-v0** | **6.34%** | 2023.03 | -| 12 | baidu_pro_api_zh | 7.29% | 2023.12 | +| 10 | **whisper-large-ft-v1-distill** | **4.71%** | 2024.04 | +| 11 | **zipformer (70Mb)** | **6.17%** | 2023.10 | +| 12 | **whisper-large-ft-v0** | **6.34%** | 
2023.03 | +| 13 | baidu_pro_api_zh | 7.29% | 2023.12 | Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67). +Despite its name, **whisper-large-ft-v1-distill** is not trained with a distillation loss; it only adopts the model structure and parameter-initialization method of the [distill-whisper](https://arxiv.org/abs/2311.00430) paper: the decoder keeps just its first and last layers. +
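To make that architecture change concrete, the sketch below shows one way a distill-whisper style decoder could be carved out of `whisper-large-v2` by keeping only its first and last decoder blocks. This is an editorial illustration with made-up variable names, not the code used in this recipe (the recipe wires the change in through `whisper_decoder_forward_monkey_patch.py` and the `--use-distill-whisper` flag visible in the `decode.py` diff below).

```python
import torch.nn as nn
import whisper

# Load the full model; whisper-large-v2 has 32 decoder blocks.
model = whisper.load_model("large-v2")

full_blocks = model.decoder.blocks  # nn.ModuleList of ResidualAttentionBlock
# Keep only the first and the last decoder block. The encoder, token
# embedding and positional embedding are reused unchanged.
model.decoder.blocks = nn.ModuleList([full_blocks[0], full_blocks[-1]])

print(len(model.decoder.blocks))  # 2
```

The reduced model is then simply fine-tuned on the ASR data with the ordinary training loss, as described above.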

Details of all models

| Model | Training Set | Note | @@ -31,7 +34,7 @@ Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leade |[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/tree/main/exp_large_v2)| wenetspeech | greedy_search, 3 epochs| |[whisper-large-ft-v0.5](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/blob/main/epoch-2-avg-5.pt)| wenetspeech(updated) | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy_search, 2 epochs | |[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 3 epochs| - +|[whisper-large-ft-v1-distill](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1-distill)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 6 epochs|

diff --git a/egs/speechio/ASR/whisper/decode.py b/egs/speechio/ASR/whisper/decode.py index 70f743eee..c20f1f714 100644 --- a/egs/speechio/ASR/whisper/decode.py +++ b/egs/speechio/ASR/whisper/decode.py @@ -58,6 +58,7 @@ from lhotse.cut import Cut from multi_dataset import MultiDataset from tn.chinese.normalizer import Normalizer from whisper.normalizers import BasicTextNormalizer +from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from zhconv import convert @@ -215,7 +216,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], + choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], help="""The model name to use. """, ) @@ -227,6 +228,13 @@ def get_parser(): help="replace whisper encoder forward method to remove input length restriction", ) + parser.add_argument( + "--use-distill-whisper", + type=str2bool, + default=False, + help="Whether to use architecture of distill whisper.", + ) + return parser @@ -431,6 +439,8 @@ def main(): if params.remove_whisper_encoder_input_length_restriction: replace_whisper_encoder_forward() + if params.use_distill_whisper: + replace_whisper_decoder_forward() model = whisper.load_model(params.model_name, "cpu") if params.epoch > 0: if params.avg > 1: diff --git a/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py b/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py new file mode 120000 index 000000000..167fba1eb --- /dev/null +++ b/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py @@ -0,0 +1 @@ +../../../multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py \ No newline at end of file From 3b40d9bbb12f045dd951d6608efbdfdbe3652ff8 Mon Sep 17 00:00:00 2001 From: Triplecq Date: Thu, 13 Jun 2024 02:19:03 -0400 Subject: [PATCH 174/216] Zipformer recipe for ReazonSpeech (#1611) * Add first cut at ReazonSpeech recipe This recipe is mostly based on egs/csj, but tweaked to the point that can be run with ReazonSpeech corpus. 
Signed-off-by: Fujimoto Seiji --------- Signed-off-by: Fujimoto Seiji Co-authored-by: Fujimoto Seiji Co-authored-by: Chen Co-authored-by: root --- egs/reazonspeech/ASR/README.md | 29 + egs/reazonspeech/ASR/RESULTS.md | 49 + .../ASR/local/compute_fbank_reazonspeech.py | 146 ++ .../ASR/local/display_manifest_statistics.py | 58 + .../ASR/local/prepare_lang_char.py | 75 + .../ASR/local/utils/asr_datamodule.py | 355 +++++ egs/reazonspeech/ASR/local/utils/tokenizer.py | 253 +++ .../ASR/local/validate_manifest.py | 96 ++ egs/reazonspeech/ASR/prepare.sh | 86 + egs/reazonspeech/ASR/shared | 1 + .../ASR/zipformer/asr_datamodule.py | 1 + egs/reazonspeech/ASR/zipformer/beam_search.py | 1 + egs/reazonspeech/ASR/zipformer/ctc_decode.py | 1 + egs/reazonspeech/ASR/zipformer/decode.py | 1076 +++++++++++++ .../ASR/zipformer/decode_stream.py | 1 + egs/reazonspeech/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/do_not_use_it_directly.py | 1261 +++++++++++++++ .../ASR/zipformer/encoder_interface.py | 1 + egs/reazonspeech/ASR/zipformer/export-onnx.py | 1 + egs/reazonspeech/ASR/zipformer/export.py | 1 + .../ASR/zipformer/generate_averaged_model.py | 1 + egs/reazonspeech/ASR/zipformer/joiner.py | 1 + egs/reazonspeech/ASR/zipformer/model.py | 1 + egs/reazonspeech/ASR/zipformer/my_profile.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/reazonspeech/ASR/zipformer/optim.py | 1 + egs/reazonspeech/ASR/zipformer/pretrained.py | 1 + egs/reazonspeech/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 597 +++++++ egs/reazonspeech/ASR/zipformer/subsampling.py | 1 + .../ASR/zipformer/test_scaling.py | 1 + .../ASR/zipformer/test_subsampling.py | 1 + egs/reazonspeech/ASR/zipformer/tokenizer.py | 1 + egs/reazonspeech/ASR/zipformer/train.py | 1383 +++++++++++++++++ egs/reazonspeech/ASR/zipformer/zipformer.py | 1 + 37 files changed, 5488 insertions(+) create mode 100644 egs/reazonspeech/ASR/README.md create mode 100644 egs/reazonspeech/ASR/RESULTS.md create mode 100644 egs/reazonspeech/ASR/local/compute_fbank_reazonspeech.py create mode 100644 egs/reazonspeech/ASR/local/display_manifest_statistics.py create mode 100644 egs/reazonspeech/ASR/local/prepare_lang_char.py create mode 100644 egs/reazonspeech/ASR/local/utils/asr_datamodule.py create mode 100644 egs/reazonspeech/ASR/local/utils/tokenizer.py create mode 100644 egs/reazonspeech/ASR/local/validate_manifest.py create mode 100755 egs/reazonspeech/ASR/prepare.sh create mode 120000 egs/reazonspeech/ASR/shared create mode 120000 egs/reazonspeech/ASR/zipformer/asr_datamodule.py create mode 120000 egs/reazonspeech/ASR/zipformer/beam_search.py create mode 120000 egs/reazonspeech/ASR/zipformer/ctc_decode.py create mode 100755 egs/reazonspeech/ASR/zipformer/decode.py create mode 120000 egs/reazonspeech/ASR/zipformer/decode_stream.py create mode 120000 egs/reazonspeech/ASR/zipformer/decoder.py create mode 100755 egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py create mode 120000 egs/reazonspeech/ASR/zipformer/encoder_interface.py create mode 120000 egs/reazonspeech/ASR/zipformer/export-onnx.py create mode 120000 egs/reazonspeech/ASR/zipformer/export.py create mode 120000 egs/reazonspeech/ASR/zipformer/generate_averaged_model.py create mode 120000 egs/reazonspeech/ASR/zipformer/joiner.py create mode 120000 egs/reazonspeech/ASR/zipformer/model.py create mode 120000 egs/reazonspeech/ASR/zipformer/my_profile.py create mode 120000 
egs/reazonspeech/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/reazonspeech/ASR/zipformer/optim.py create mode 120000 egs/reazonspeech/ASR/zipformer/pretrained.py create mode 120000 egs/reazonspeech/ASR/zipformer/scaling.py create mode 120000 egs/reazonspeech/ASR/zipformer/scaling_converter.py create mode 120000 egs/reazonspeech/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/reazonspeech/ASR/zipformer/streaming_decode.py create mode 120000 egs/reazonspeech/ASR/zipformer/subsampling.py create mode 120000 egs/reazonspeech/ASR/zipformer/test_scaling.py create mode 120000 egs/reazonspeech/ASR/zipformer/test_subsampling.py create mode 120000 egs/reazonspeech/ASR/zipformer/tokenizer.py create mode 100755 egs/reazonspeech/ASR/zipformer/train.py create mode 120000 egs/reazonspeech/ASR/zipformer/zipformer.py diff --git a/egs/reazonspeech/ASR/README.md b/egs/reazonspeech/ASR/README.md new file mode 100644 index 000000000..ad5c15de3 --- /dev/null +++ b/egs/reazonspeech/ASR/README.md @@ -0,0 +1,29 @@ +# Introduction + + + +**ReazonSpeech** is an open-source dataset that contains a diverse set of natural Japanese speech, collected from terrestrial television streams. It contains more than 35,000 hours of audio. + + + +The dataset is available on Hugging Face. For more details, please visit: + +- Dataset: https://huggingface.co/datasets/reazon-research/reazonspeech +- Paper: https://research.reazon.jp/_static/reazonspeech_nlp2023.pdf + + + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + + + +There are various folders containing the name `transducer` in this folder. The following table lists the differences among them. + +| | Encoder | Decoder | Comment | +| ---------------------------------------- | -------------------- | ------------------ | ------------------------------------------------- | +| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe | + +The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer. 
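As a rough sketch of that stateless prediction network (the class name, dimensions, and padding scheme below are illustrative assumptions, not the recipe's actual `zipformer/decoder.py`, which differs in details), the decoder is just an embedding followed by a Conv1d over the previous few labels:

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding + Conv1d prediction network; no recurrent state."""

    def __init__(self, vocab_size: int, decoder_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_dim)
        # The Conv1d sees the current label and the previous
        # (context_size - 1) labels only.
        self.conv = nn.Conv1d(
            decoder_dim, decoder_dim, kernel_size=context_size, bias=False
        )
        self.context_size = context_size

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) label IDs
        emb = self.embedding(y).permute(0, 2, 1)                   # (N, D, U)
        emb = nn.functional.pad(emb, (self.context_size - 1, 0))   # left-pad in time
        return self.conv(emb).permute(0, 2, 1)                     # (N, U, D)


# Example: 3 label sequences of length 5 over a 500-symbol vocabulary.
decoder = StatelessDecoder(vocab_size=500, decoder_dim=512)
print(decoder(torch.randint(0, 500, (3, 5))).shape)  # torch.Size([3, 5, 512])
```

Because the output depends only on a small, fixed label context rather than on an unbounded RNN state, such a decoder is cheap to evaluate and convenient for streaming and batched beam search.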
+ diff --git a/egs/reazonspeech/ASR/RESULTS.md b/egs/reazonspeech/ASR/RESULTS.md new file mode 100644 index 000000000..c0b4fe54a --- /dev/null +++ b/egs/reazonspeech/ASR/RESULTS.md @@ -0,0 +1,49 @@ +## Results + +### Zipformer + +#### Non-streaming + +##### large-scaled model, number of model parameters: 159337842, i.e., 159.34 M + +| decoding method | In-Distribution CER | JSUT | CommonVoice | TEDx | comment | +| :------------------: | :-----------------: | :--: | :---------: | :---: | :----------------: | +| greedy search | 4.2 | 6.7 | 7.84 | 17.9 | --epoch 39 --avg 7 | +| modified beam search | 4.13 | 6.77 | 7.69 | 17.82 | --epoch 39 --avg 7 | + +The training command is: + +```shell +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 40 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --causal 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --lang data/lang_char \ + --max-duration 1600 +``` + +The decoding command is: + +```shell +./zipformer/decode.py \ + --epoch 40 \ + --avg 16 \ + --exp-dir zipformer/exp-large \ + --max-duration 600 \ + --causal 0 \ + --decoding-method greedy_search \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --lang data/lang_char \ + --blank-penalty 0 +``` + diff --git a/egs/reazonspeech/ASR/local/compute_fbank_reazonspeech.py b/egs/reazonspeech/ASR/local/compute_fbank_reazonspeech.py new file mode 100644 index 000000000..af7841406 --- /dev/null +++ b/egs/reazonspeech/ASR/local/compute_fbank_reazonspeech.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import os +from pathlib import Path +from typing import List, Tuple + +import torch + +# fmt: off +from lhotse import ( # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527 + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + RecordingSet, + SupervisionSet, +) + +# fmt: on + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +RNG_SEED = 42 +concat_params = {"gap": 1.0, "maxlen": 10.0} + + +def make_cutset_blueprints( + manifest_dir: Path, +) -> List[Tuple[str, CutSet]]: + cut_sets = [] + + # Create test dataset + logging.info("Creating test cuts.") + cut_sets.append( + ( + "test", + CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "reazonspeech_recordings_test.jsonl.gz" + ), + supervisions=SupervisionSet.from_file( + manifest_dir / "reazonspeech_supervisions_test.jsonl.gz" + ), + ), + ) + ) + + # Create dev dataset + logging.info("Creating dev cuts.") + cut_sets.append( + ( + "dev", + CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "reazonspeech_recordings_dev.jsonl.gz" + ), + supervisions=SupervisionSet.from_file( + manifest_dir / "reazonspeech_supervisions_dev.jsonl.gz" + ), + ), + ) + ) + + # Create train dataset + logging.info("Creating train cuts.") + cut_sets.append( + ( + "train", + CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "reazonspeech_recordings_train.jsonl.gz" + ), + supervisions=SupervisionSet.from_file( + manifest_dir / "reazonspeech_supervisions_train.jsonl.gz" + ), + ), + ) + ) + return cut_sets + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", "--manifest-dir", type=Path) + return parser.parse_args() + + +def main(): + args = get_args() + + extractor = Fbank(FbankConfig(num_mel_bins=80)) + num_jobs = min(16, os.cpu_count()) + + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + if (args.manifest_dir / ".reazonspeech-fbank.done").exists(): + logging.info( + "Previous fbank computed for ReazonSpeech found. " + f"Delete {args.manifest_dir / '.reazonspeech-fbank.done'} to allow recomputing fbank." + ) + return + else: + cut_sets = make_cutset_blueprints(args.manifest_dir) + for part, cut_set in cut_sets: + logging.info(f"Processing {part}") + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + num_jobs=num_jobs, + storage_path=(args.manifest_dir / f"feats_{part}").as_posix(), + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz") + + logging.info("All fbank computed for ReazonSpeech.") + (args.manifest_dir / ".reazonspeech-fbank.done").touch() + + +if __name__ == "__main__": + main() diff --git a/egs/reazonspeech/ASR/local/display_manifest_statistics.py b/egs/reazonspeech/ASR/local/display_manifest_statistics.py new file mode 100644 index 000000000..ace1dd73f --- /dev/null +++ b/egs/reazonspeech/ASR/local/display_manifest_statistics.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2022 The University of Electro-Communications (author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path + +from lhotse import CutSet, load_manifest + +ARGPARSE_DESCRIPTION = """ +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in +pruned_transducer_stateless5/train.py for usage. +""" + + +def get_parser(): + parser = argparse.ArgumentParser( + description=ARGPARSE_DESCRIPTION, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") + + return parser.parse_args() + + +def main(): + args = get_parser() + + for part in ["train", "dev"]: + path = args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz" + cuts: CutSet = load_manifest(path) + + print("\n---------------------------------\n") + print(path.name + ":") + cuts.describe() + + +if __name__ == "__main__": + main() diff --git a/egs/reazonspeech/ASR/local/prepare_lang_char.py b/egs/reazonspeech/ASR/local/prepare_lang_char.py new file mode 100644 index 000000000..19c5f4a31 --- /dev/null +++ b/egs/reazonspeech/ASR/local/prepare_lang_char.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=Path("data/lang_char"), + help=( + "Name of lang dir. 
" + "If not set, this will default to lang_char_{trans-mode}" + ), + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.basicConfig( + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), + level=logging.INFO, + ) + + sysdef_string = set(["", "", "", " "]) + + token_set = set() + logging.info(f"Creating vocabulary from {args.train_cut}.") + train_cut: CutSet = CutSet.from_file(args.train_cut) + for cut in train_cut: + for sup in cut.supervisions: + token_set.update(sup.text) + + token_set = [""] + sorted(token_set - sysdef_string) + ["", ""] + args.lang_dir.mkdir(parents=True, exist_ok=True) + (args.lang_dir / "tokens.txt").write_text( + "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) + ) + + (args.lang_dir / "lang_type").write_text("char") + logging.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/egs/reazonspeech/ASR/local/utils/asr_datamodule.py b/egs/reazonspeech/ASR/local/utils/asr_datamodule.py new file mode 100644 index 000000000..e70370760 --- /dev/null +++ b/egs/reazonspeech/ASR/local/utils/asr_datamodule.py @@ -0,0 +1,355 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class ReazonSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/dev/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=False, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders( + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + + transforms = [] + input_transforms = [] + + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + 
persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/reazonspeech/ASR/local/utils/tokenizer.py b/egs/reazonspeech/ASR/local/utils/tokenizer.py new file mode 100644 index 000000000..c9be72be1 --- /dev/null +++ b/egs/reazonspeech/ASR/local/utils/tokenizer.py @@ -0,0 +1,253 @@ +import argparse +from pathlib import Path +from typing import Callable, List, Union + +import sentencepiece as spm +from k2 import SymbolTable + + +class Tokenizer: + text2word: Callable[[str], List[str]] + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser): + group = parser.add_argument_group(title="Lang related options") + + group.add_argument("--lang", type=Path, help="Path to lang directory.") + + group.add_argument( + "--lang-type", + type=str, + default=None, + help=( + "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. " + "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor" + ), + ) + + @staticmethod + def Load(lang_dir: Path, lang_type="", oov=""): + + if not lang_type: + assert (lang_dir / "lang_type").exists(), "lang_type not specified." + lang_type = (lang_dir / "lang_type").read_text().strip() + + tokenizer = None + + if lang_type == "bpe": + assert ( + lang_dir / "bpe.model" + ).exists(), f"No BPE .model could be found in {lang_dir}." + tokenizer = spm.SentencePieceProcessor() + tokenizer.Load(str(lang_dir / "bpe.model")) + elif lang_type == "char": + tokenizer = CharTokenizer(lang_dir, oov=oov) + else: + raise NotImplementedError(f"{lang_type} not supported at the moment.") + + return tokenizer + + load = Load + + def PieceToId(self, piece: str) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + piece_to_id = PieceToId + + def IdToPiece(self, id: int) -> str: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + id_to_piece = IdToPiece + + def GetPieceSize(self) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + get_piece_size = GetPieceSize + + def __len__(self) -> int: + return self.get_piece_size() + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + raise NotImplementedError( + "You need to implement this function in the child class." 
+ )
+
+ def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def EncodeAsIds(self, input: str) -> List[int]:
+ return self.EncodeAsIdsBatch([input])[0]
+
+ def EncodeAsPieces(self, input: str) -> List[str]:
+ return self.EncodeAsPiecesBatch([input])[0]
+
+ def Encode(
+ self, input: Union[str, List[str]], out_type=int
+ ) -> Union[List, List[List]]:
+ if not input:
+ return []
+
+ if isinstance(input, list):
+ if out_type is int:
+ return self.EncodeAsIdsBatch(input)
+ if out_type is str:
+ return self.EncodeAsPiecesBatch(input)
+
+ if out_type is int:
+ return self.EncodeAsIds(input)
+ if out_type is str:
+ return self.EncodeAsPieces(input)
+
+ encode = Encode
+
+ def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def DecodeIds(self, input: List[int]) -> str:
+ return self.DecodeIdsBatch([input])[0]
+
+ def DecodePieces(self, input: List[str]) -> str:
+ return self.DecodePiecesBatch([input])[0]
+
+ def Decode(
+ self,
+ input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
+ ) -> Union[List[str], str]:
+
+ if not input:
+ return ""
+
+ if isinstance(input, int):
+ return self.id_to_piece(input)
+ elif isinstance(input, str):
+ raise TypeError(
+ "Unlike spm.SentencePieceProcessor, cannot decode from type str."
+ )
+
+ if isinstance(input[0], list):
+ if not input[0] or isinstance(input[0][0], int):
+ return self.DecodeIdsBatch(input)
+
+ if isinstance(input[0][0], str):
+ return self.DecodePiecesBatch(input)
+
+ if isinstance(input[0], int):
+ return self.DecodeIds(input)
+ if isinstance(input[0], str):
+ return self.DecodePieces(input)
+
+ raise RuntimeError("Unknown input type")
+
+ decode = Decode
+
+ def SplitBatch(self, input: List[str]) -> List[List[str]]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
+ if isinstance(input, list):
+ return self.SplitBatch(input)
+ elif isinstance(input, str):
+ return self.SplitBatch([input])[0]
+ raise RuntimeError("Unknown input type")
+
+ split = Split
+
+
+class CharTokenizer(Tokenizer):
+ def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
+ assert (
+ lang_dir / "tokens.txt"
+ ).exists(), f"tokens.txt could not be found in {lang_dir}."
+ token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
+ assert (
+ "#0" not in token_table
+ ), "This tokenizer does not support disambig symbols."
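+ # Cache the symbol<->id maps from tokens.txt; piece_to_id() falls back to the OOV id for pieces that are not in the table.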
+ self._id2sym = token_table._id2sym
+ self._sym2id = token_table._sym2id
+ self.oov = oov
+ self.oov_id = self._sym2id[oov]
+ self.sep = sep
+ if self.sep:
+ self.text2word = lambda x: x.split(self.sep)
+ else:
+ self.text2word = lambda x: list(x.replace(" ", ""))
+
+ def piece_to_id(self, piece: str) -> int:
+ try:
+ return self._sym2id[piece]
+ except KeyError:
+ return self.oov_id
+
+ def id_to_piece(self, id: int) -> str:
+ return self._id2sym[id]
+
+ def get_piece_size(self) -> int:
+ return len(self._sym2id)
+
+ def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
+ return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]
+
+ def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
+ return [
+ [i if i in self._sym2id else self.oov for i in self.text2word(text)]
+ for text in input
+ ]
+
+ def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
+ return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]
+
+ def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
+ return [self.sep.join(text) for text in input]
+
+ def SplitBatch(self, input: List[str]) -> List[List[str]]:
+ return [self.text2word(text) for text in input]
+
+
+def test_CharTokenizer():
+ test_single_string = "こんにちは"
+ test_multiple_string = [
+ "今日はいい天気ですよね",
+ "諏訪湖は綺麗でしょう",
+ "这在词表外",
+ "分かち 書き に し た 文章 です",
+ "",
+ ]
+ test_empty_string = ""
+ sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
+ splitter = sp.split
+ print(sp.encode(test_single_string, out_type=str))
+ print(sp.encode(test_single_string, out_type=int))
+ print(sp.encode(test_multiple_string, out_type=str))
+ print(sp.encode(test_multiple_string, out_type=int))
+ print(sp.encode(test_empty_string, out_type=str))
+ print(sp.encode(test_empty_string, out_type=int))
+ print(sp.decode(sp.encode(test_single_string, out_type=str)))
+ print(sp.decode(sp.encode(test_single_string, out_type=int)))
+ print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
+ print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
+ print(sp.decode(sp.encode(test_empty_string, out_type=str)))
+ print(sp.decode(sp.encode(test_empty_string, out_type=int)))
+ print(splitter(test_single_string))
+ print(splitter(test_multiple_string))
+ print(splitter(test_empty_string))
+
+
+if __name__ == "__main__":
+ test_CharTokenizer()
diff --git a/egs/reazonspeech/ASR/local/validate_manifest.py b/egs/reazonspeech/ASR/local/validate_manifest.py
new file mode 100644
index 000000000..7f67c64b6
--- /dev/null
+++ b/egs/reazonspeech/ASR/local/validate_manifest.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+- Supervision time bounds are within cut time bounds
+
+We will add more checks later if needed.
+ +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def validate_one_supervision_per_cut(c: Cut): + if len(c.supervisions) != 1: + raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") + + +def validate_supervision_and_cut_time_bounds(c: Cut): + s = c.supervisions[0] + + # Removed because when the cuts were trimmed from supervisions, + # the start time of the supervision can be lesser than cut start time. + # https://github.com/lhotse-speech/lhotse/issues/813 + # if s.start < c.start: + # raise ValueError( + # f"{c.id}: Supervision start time {s.start} is less " + # f"than cut start time {c.start}" + # ) + + if s.end > c.end: + raise ValueError( + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" + ) + + +def main(): + args = get_args() + + manifest = Path(args.manifest) + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest(manifest) + assert isinstance(cut_set, CutSet) + + for c in cut_set: + validate_one_supervision_per_cut(c) + validate_supervision_and_cut_time_bounds(c) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/reazonspeech/ASR/prepare.sh b/egs/reazonspeech/ASR/prepare.sh new file mode 100755 index 000000000..d5e0a9491 --- /dev/null +++ b/egs/reazonspeech/ASR/prepare.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/ReazonSpeech +# You can find FLAC files in this directory. +# You can download them from https://huggingface.co/datasets/reazon-research/reazonspeech +# +# - $dl_dir/dataset.json +# The metadata of the ReazonSpeech dataset. + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "Running prepare.sh" + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/ReazonSpeech, + # you can create a symlink + # + # ln -sfv /path/to/ReazonSpeech $dl_dir/ReazonSpeech + # + if [ ! -d $dl_dir/ReazonSpeech/downloads ]; then + # Download small-v1 by default. + lhotse download reazonspeech --subset small-v1 $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare ReazonSpeech manifest" + # We assume that you have downloaded the ReazonSpeech corpus + # to $dl_dir/ReazonSpeech + mkdir -p data/manifests + if [ ! 
-e data/manifests/.reazonspeech.done ]; then + lhotse prepare reazonspeech -j $nj $dl_dir/ReazonSpeech data/manifests + touch data/manifests/.reazonspeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute ReazonSpeech fbank" + if [ ! -e data/manifests/.reazonspeech-validated.done ]; then + python local/compute_fbank_reazonspeech.py --manifest-dir data/manifests + python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_train.jsonl.gz + python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_dev.jsonl.gz + python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_test.jsonl.gz + touch data/manifests/.reazonspeech-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare ReazonSpeech lang_char" + python local/prepare_lang_char.py data/manifests/reazonspeech_cuts_train.jsonl.gz +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Show manifest statistics" + python local/display_manifest_statistics.py --manifest-dir data/manifests > data/manifests/manifest_statistics.txt + cat data/manifests/manifest_statistics.txt +fi \ No newline at end of file diff --git a/egs/reazonspeech/ASR/shared b/egs/reazonspeech/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/reazonspeech/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/asr_datamodule.py b/egs/reazonspeech/ASR/zipformer/asr_datamodule.py new file mode 120000 index 000000000..a48591198 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/beam_search.py b/egs/reazonspeech/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/ctc_decode.py b/egs/reazonspeech/ASR/zipformer/ctc_decode.py new file mode 120000 index 000000000..faa8bd562 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/ctc_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/decode.py b/egs/reazonspeech/ASR/zipformer/decode.py new file mode 100755 index 000000000..cdd2145f2 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/decode.py @@ -0,0 +1,1076 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import ReazonSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from tokenizer import Tokenizer +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
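+
+ For illustration (hypothetical output), a greedy-search result for a
+ two-utterance batch might look like:
+ {"greedy_search": [["hello", "world"], ["good", "morning"]]}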
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = 
modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.text2word(sp.decode(hyp))) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. 
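+ context_graph:
+ An optional ContextGraph holding context-biasing words/phrases; used
+ only by the modified_beam_search family of methods.
+ LM:
+ A neural network language model used for shallow fusion or rescoring.
+ ngram_lm:
+ An n-gram language model used by the LODR decoding methods.
+ ngram_lm_scale:
+ The scale applied to the n-gram language model scores.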
+ Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = sp.text2word(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + ReazonSpeechAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
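+ # Append the chunk size and left-context frames to the result-file suffix when decoding a causal (streaming) model.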
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # and are defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + 
model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + + for subdir in ["valid"]: + results_dict = decode_dataset( + dl=reazonspeech_corpus.test_dataloaders( + getattr(reazonspeech_corpus, f"{subdir}_cuts")() + ), + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + tot_err = save_results( + params=params, + test_set_name=subdir, + results_dict=results_dict, + ) + # with ( + # params.res_dir + # / ( + # f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" + # f"_{params.avg}_{params.epoch}.cer" + # ) + # ).open("w") as fout: + # if len(tot_err) == 1: + # fout.write(f"{tot_err[0][1]}") + # else: + # fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/reazonspeech/ASR/zipformer/decode_stream.py b/egs/reazonspeech/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/decoder.py b/egs/reazonspeech/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py b/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py new file mode 100755 index 000000000..072679cfc --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py @@ -0,0 +1,1261 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
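+# Note: as the file name suggests, this script is not meant for regular
+# training. It mirrors train.py but builds the Zipformer variant from
+# zipformer_for_ncnn_export_only.py (is_pnnx=True) and exists only to support
+# exporting streaming models to ncnn/pnnx.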
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import ReazonSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer_for_ncnn_export_only import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] 
= None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
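+ # Concretely (matching the checks below, which only run every 100 batches
+ # when fp16 is enabled): while the scale has collapsed below 1.0 it is
+ # doubled each time this block runs (e.g. 0.25 -> 0.5 -> 1.0 over 200
+ # batches); between 1.0 and 8.0 it is only doubled every 400 batches; above
+ # 8.0 it is left to the GradScaler's own growth schedule. A scale below
+ # 0.01 triggers a warning, and below 1e-5 training is aborted.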
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, master_port=params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 0.3 or c.duration > 30.0: + logging.debug( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.info( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. 
" + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + train_cuts = reazonspeech_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = reazonspeech_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = reazonspeech_corpus.valid_cuts() + valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + raise RuntimeError("Please don't use this file directly!") + parser = get_parser() + ReazonSpeechAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/reazonspeech/ASR/zipformer/encoder_interface.py b/egs/reazonspeech/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/export-onnx.py b/egs/reazonspeech/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/export.py b/egs/reazonspeech/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/generate_averaged_model.py b/egs/reazonspeech/ASR/zipformer/generate_averaged_model.py new file mode 120000 index 000000000..5a015ee6c --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/joiner.py b/egs/reazonspeech/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/model.py b/egs/reazonspeech/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/my_profile.py b/egs/reazonspeech/ASR/zipformer/my_profile.py new file mode 120000 index 000000000..3a90b2628 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/my_profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git 
a/egs/reazonspeech/ASR/zipformer/onnx_pretrained.py b/egs/reazonspeech/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/optim.py b/egs/reazonspeech/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/pretrained.py b/egs/reazonspeech/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/scaling.py b/egs/reazonspeech/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/scaling_converter.py b/egs/reazonspeech/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/streaming_beam_search.py b/egs/reazonspeech/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..4c18c7563 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py @@ -0,0 +1,597 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --lang data/lang_char \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +from asr_datamodule import ReazonSpeechAsrDataModule +from decode import save_results +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
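+ # Rough arithmetic behind the constant below: padding every chunk to at
+ # least 23 frames guarantees (23 - 7) // 2 = 8 frames after the front-end
+ # embedding, i.e. at least one frame even at the maximum internal
+ # downsampling factor of 8, so no encoder stack receives an empty input.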
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
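+ # A DecodeStream (from decode_stream.py) bundles the per-utterance state
+ # used below: the initial encoder states, the utterance's features, the
+ # number of frames consumed so far and the partial decoding result, so
+ # that many utterances can be decoded chunk by chunk in one batch.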
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +@torch.no_grad() +def main(): + parser = get_parser() + ReazonSpeechAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if not params.res_dir: + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # and is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + 
params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + args.return_cuts = True + reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + + for subdir in ["valid"]: + results_dict = decode_dataset( + cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, test_set_name=subdir, results_dict=results_dict + ) + + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}") for 
k, v in tot_err) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/reazonspeech/ASR/zipformer/subsampling.py b/egs/reazonspeech/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/test_scaling.py b/egs/reazonspeech/ASR/zipformer/test_scaling.py new file mode 120000 index 000000000..715798436 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/test_scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/test_subsampling.py b/egs/reazonspeech/ASR/zipformer/test_subsampling.py new file mode 120000 index 000000000..bf0ee3d11 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/test_subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/tokenizer.py b/egs/reazonspeech/ASR/zipformer/tokenizer.py new file mode 120000 index 000000000..958c99e85 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/tokenizer.py @@ -0,0 +1 @@ +../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/train.py b/egs/reazonspeech/ASR/zipformer/train.py new file mode 100755 index 000000000..8c6f4bb9a --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/train.py @@ -0,0 +1,1383 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import ReazonSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.015, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + 
vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. 
It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
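+ # With the thresholds below, a scale under 8.0 is doubled on every pass through
+ # this block (every 100 batches), while a scale between 8.0 and 32.0 is doubled
+ # only every 400 batches; e.g. a scale of 1.0 climbs back to 8.0 after roughly
+ # 300 batches once overflows stop.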
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # <blk> is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("<blk>") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 30 seconds + # + # Caution: There is a reason to select 30.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + train_cuts = reazonspeech_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = reazonspeech_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = reazonspeech_corpus.valid_cuts() + valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + ReazonSpeechAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/reazonspeech/ASR/zipformer/zipformer.py b/egs/reazonspeech/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 890eeec82c221c7d8a232b0ba3b3b1c4663859ef Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Sun, 16 Jun 2024 12:14:44 +0800 Subject: [PATCH 175/216] Add qwen-audio style model training: using whisper + qwen2 (#1652) --- egs/speech_llm/ASR_LLM/README.md | 20 + egs/speech_llm/ASR_LLM/RESULTS.md | 62 ++ egs/speech_llm/ASR_LLM/assets/framework.png | Bin 0 -> 853635 bytes egs/speech_llm/ASR_LLM/prepare.sh | 46 + .../ASR_LLM/whisper_llm_zh/asr_datamodule.py | 1 + .../ASR_LLM/whisper_llm_zh/decode.py | 650 +++++++++++++ .../whisper_llm_zh/ds_config_zero1.json | 38 + .../ASR_LLM/whisper_llm_zh/model.py | 285 ++++++ .../ASR_LLM/whisper_llm_zh/multi_dataset.py | 338 +++++++ .../ASR_LLM/whisper_llm_zh/requirements.txt | 11 + .../ASR_LLM/whisper_llm_zh/train.py | 872 ++++++++++++++++++ .../whisper_encoder_forward_monkey_patch.py | 1 + 12 files changed, 2324 insertions(+) create mode 100644 egs/speech_llm/ASR_LLM/README.md create mode 100644 egs/speech_llm/ASR_LLM/RESULTS.md create mode 100644 egs/speech_llm/ASR_LLM/assets/framework.png create mode 100644 egs/speech_llm/ASR_LLM/prepare.sh create mode 120000 egs/speech_llm/ASR_LLM/whisper_llm_zh/asr_datamodule.py create mode 100755 egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py create mode 100644 egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json create mode 100644 egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py create mode 100644 egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py create mode 100644 egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt create mode 100755 egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py create mode 120000 egs/speech_llm/ASR_LLM/whisper_llm_zh/whisper_encoder_forward_monkey_patch.py diff --git a/egs/speech_llm/ASR_LLM/README.md b/egs/speech_llm/ASR_LLM/README.md new file mode 100644 index 000000000..171240db0 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/README.md @@ -0,0 +1,20 @@ + +# Introduction + +This recipe includes 
scripts for training a [Qwen-Audio](https://github.com/QwenLM/Qwen-Audio/tree/main)-style model using multiple datasets. +
+
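+As a rough, illustrative sketch only (the module and argument names below are placeholders, not code from this recipe), a Qwen-Audio-style model couples a speech encoder (here a Whisper encoder) to a causal LLM (here Qwen2) through a small projector:
+
+```python
+import torch
+import torch.nn as nn
+
+
+class SpeechLLM(nn.Module):
+    """Toy Qwen-Audio-style wrapper: speech encoder -> projector -> LLM."""
+
+    def __init__(self, encoder: nn.Module, llm: nn.Module, encoder_dim: int, llm_dim: int):
+        super().__init__()
+        self.encoder = encoder  # stand-in for a (frozen) Whisper encoder
+        self.projector = nn.Linear(encoder_dim, llm_dim)
+        self.llm = llm  # stand-in for a Qwen2 causal LM
+
+    def forward(self, fbank: torch.Tensor, text_embeds: torch.Tensor) -> torch.Tensor:
+        # fbank: (N, T, num_mel_bins); text_embeds: (N, U, llm_dim)
+        audio_embeds = self.projector(self.encoder(fbank))  # (N, T', llm_dim)
+        # Prepend the projected audio embeddings to the text embeddings and
+        # let the LLM attend over both.
+        return self.llm(torch.cat([audio_embeds, text_embeds], dim=1))
+```
+
+The actual model, training loop, and DeepSpeed configuration used by this recipe live under `whisper_llm_zh/`.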
+

+ +

zsHKTg#F3*%Q*p-no1|GmlEJz0#?O+2BaQO=M;~o&ZKc=r@#7!+IQWE{ zq;8J8o_x^Mv!;?cX&|KjsspemP;$S3aDwX3$io&vFUza8!qXEgaDC5nXA=fCiI5;vq}`H7t& z<;uD7v!7+yDLqTvq{$bBw`{IL`jYwGBn`>STWf%l@*^=vg(i7+mUQI$zv$wNM~@lv z%1bYe8$D(?XIX|J89l>t@x>Rj;NvfU@k@$8@xXT8)2B~o^Bj_=GWcZwnYqshGrWMw z|B^3T^1}1vy8#fne85q!_r))LiS5xD{D&WTn0dkH&O7(|>u=yn16Fsue#f@$i7k(s zo1Jk;3`F8zvYL*x7b(l1+;tbz|5utBy)gLDdNxqWA zJ)6{Xrbdk%1qoLgwW2uXql}W*n{N6tlkVB?Yss_E5ei>sKkybz1crgXJZy5zrA6wkTV(RZT3v?xAUAwx(o=1Ax$s zHp-XjH!KT-U@T_y&`S~sIM5_iDFI%Orcxr03Q}4HAB8d5{&?iZ8 zGb)xI@71T@#A)ZuyKupX@l#7n`;)e#o5$7erN`~UcjC4cxGxkU@`&2Rt5<)8R045pq>((l~<6&KG8>Tmr0 z9hWS)hTmRIyvTBDwr|<+gMa$^%4IJqAS)+_`P%1PxM0$ZbJ#YgV!*KM+`MB)4sBSw za>*a=d*{ui2WzUeXq3{+zr6k5XUw@c@fRi8sa-p_eDC(JDB`9aCq{gI_#) z`~L1*uip7De~k+XW7fHs-t^7?GknY`(;0sLqksDSeRt`+b0D zA0GYNaMrBzXPi4fB*IC z*iGB|br-;wzxhue{p4S;rl*jY@2lQ^_51($m(2YaC2PX`>rWq_IDL+Z$xODwvko)k z?4*^(VDWfQx`ok0NCYozQMP8QnkY$>U|N}&lnikI5Sr0O`7-^6WnmDE#cUpWNdf@} znuIDPKwvwXYu{yC?WSYwX}}lZ;rZoriYg>2DUn-2Zf-pK;-U9x;xR3&XAUIHw>#0T zeMc;BFHrVO7GUy)qopnJC|QtQd`-XcJx_UXJ;ii;b1iegk2D`^JJHJbi;)EtL}fm| zo`UYgfwzw)pBI4A8Ww@^zHi4GnkJ1JG-3Ealjn^5XE*|s%l*SY`~%o6us*TOUgi#Dyh5)!jVQMSOC99q|Sq@l4Mcl<`kva>t)eZc2T|B@kj zIR(tow!+|0B6K;`scRw{6F~Zauf&HB9;6_M)fj6I9H5GaWF<(FgPhQYy za|wVrNs2KzhYlX9Ianj1n(yQj*_?;#fpjGCF!h)Z*ujGaF)^7{I1`zvmKM$n545C6 z7;P##={1~WflVHjIa!=mGRP&I*`gv&G$yQ_M;=E;1|->#>V4H)H*cnTjphs zBY4NoomAoRf}FYJ2x28CddCjZ!CTlSg|rdpn(9T0B`sPZ=d3XfjC^A2=d3K!g^A%x zenOI3+E-moy&I*4?ASA`au}C~qzDO#uqoN=SJux=8p2sy$2jaK-wl|-M+12G&Rtu# zZN(8vE?uf(S)-G97-52Brawb^sOA78!#EI${#mzTrgywlN$^}S1W$1yfAqQbI8(9? zBT06yMwWosxpOBj*dJyn`R~Dl>7AT7dUs&oejIV$Nh$Kngt>1|(of=z(dNT}lf{YS zdfvNx5B0@Nif7Czqmf7?Dl;%ON8n21N;p(g!zlpCE>c|ZR4CI#{Cr;VQC7WwUs-8C zMmTY%lAF~K9!+56bR>}b_me{>UC8B1?w))i)2p+RYsCVS80a$B0~amPqU6PHIjk*9 zz}&x_1!kDBL6#Jp3c!5lQ)54Mm}1mcoGL(QO$KdJTeCI3EIcNO#$b||08{UkC3v&R)kv#JYs#jr$Yz{TNc7@J zz)B+cs2CK6U@?T zd`c@5tPCPZljT(LU;_Lk8PC6TA>V-)-TS>Q8{boA){*4j;RE~LdHn?@R`a(oH#a}| zBNl&3SQeVS3Xdq*6!z}kzH8g2@l$587NmAc1wl%vx%q`-CQM!R#*2IR>@XRIRgEbh z$g+}b2hsi+BCTjfj-I%(U-|y(T_&DgaRv_^PK2h?lBwzY_-AfDa;RqW`qk3Z($qv@ z=F<=VV(H=sdFIK=&gHR*PyEJ)`eTO=HZ~sD0qoVM^qSBAw+WMH@%PxGP5Lj;&a>q(?GUangj5Nj~JA)R+JHF-?gjYnE7 zU#5x57Dk-{h-y;33d&{*LbbD{Zd=2#{f#x6o!?ZyYxD7Hl9EGm5=%+C9||=s?LSs? 
zkFue*a&pKTk*z1%niD(NBwo2a3;JXi?`y2))?|-tG+^X5ZE9;m;6FD=NkkrQI?{Tg znUAa_Q@0IuJ3-GWKADk}_$_m5d^ofK!lv&gUX>yQngBo>NzO-(`Iv0+o;mz;8UZRz zs259IFg*$e8sj8}?7X+M}L*Z4T zBYdmpgN}|2Ini&s%$BrE{a>NMok0jT0D9ivbEF~A% zFOjiDc3ot3>q&Ltz=4CBsEuNcPIf@=*}fMVoh8L79&yb}U1%60@D4KrP9>~r=w1J^ zGXANC0I1%Q#4G6-E(zQTCrzYO@x7=ba=VFVzX?FfKF6{J&K zIa0!gKA|y6SGkG?4IW%kF@VGqsVD+dTQX$@Lx<|HK#MHoZT%5S(pJ;{*OI)8Qy)Qz8S#C+`0`J@lW*8V+ zA72(5Nql}o^73~^8M=aP-egLWlVs~NIu1X%nw0U8rj;ctNms>h@t^KPy*Ey`qbPoB z+;luuh%!8*|1d|uI!&j;b>NKW7Z#s)=|X;?dhi!Ncz@NKS`tVY>`tQ_)Cfb$P>&v} z`Snl!-@*NR7kv6lOa>41h0n^-BvaZ7iNf?lf>O4?CkP`5Ob(eQurqfaC|e62X2KIh`+AHUy9M^2k_KKl!KbFzV}%L*|*(lRs;mBbsr ziTvY~30u4HpAw-pZW@dV|0z-ZaXMw?$G(MmBlCN|Q@ek6Vmq1>KVoudL2k>^oL=V?4|?rbcTx$} z4T+?r-jU=nf4l$P5maIp573O~vf91-n(GcX9VHRjcA`BInV((Aw%UVp+1&Az**$T_ zv#ZW{SMkJ&`3&a2(g={*AtOn?V8XHR;-R{NNeE-(n#CI*}uLL*~t5k>`I!xAShQDiFkM0m+0X_WWGRc#cdM!`%% z%cC6PNC!Va5@-^hL^jZf-}NPb+*9(nhG@KFD9WBxUG|K;yM@hb?se4e%Xe)nNmLgE zKuLxB;#8W6NEe_QMMim1q*w4cZwEpcX!WvrmG8gzUiJR{G*u;*r|?X1^ID};M&A4> zEd}slK0BJe0$;;`$rHR6W0KGXFD9UmIV@h7={gW$N?QOXm5UehFvhit(g<5hr%~pv zK4oA6;N3bq)}a7g{*)S*M@_6I)tiE1RM6IxT?6S_sksQ+uP_@t;02>dQCfu7p@J=a z)L~lq{KBHM=UqB%_-M9#S^U^>;Y^Li#gkBe> ztu+)bUdSyI7Zy_^Y$=__WQ)Tik@Iey1?x}%@%<{&<%w!iy(y@<3fdZH%v~!r7eV_K zW`l>9U=%4z;fZ=za1_AOcOgDZ<&KPtN24$mtbnaUF;S%skwq?gQyJzDw~9NHO#?uP z(wi8K(^O_4<10?6$}E1(2u(R=RHlhCpC-Zd73e823Eqn_Nr=;n3Ftf1HPdwq7MdS1$l?q>8KvWE!(VLD`HX|{~ z1t6^oHUXzY#!w1a7a~NfoGeD{Ja8wgjN$)pOIf# z&6u?{Euo;Wh?(PQnM8q^nPCQL!U~Fu@e&m+<2mD;3)d`vdDo6DVpTO}Qq{AI|!PP`jrK{Jl5mI{5IK4=vA_gxY z-Pq#bO-Cx5k(eX`kX8kofYTx4D9R(Hj=LLcDM6fBSk9a`HrM7KNlBifHe?s=ZLZsM zyt?f~3s1EaOYS_*M;{aQ+A9_bm8INd6-Y@hRyQ7a;K1QN*+q@Xr6Hv+2byc&I99!} zeh=$E%F`KOW^gyQCVnhIhQF=VCMOHoC#$GDr?ARNL~d`a-PK$>yQrc)37WjKl2qOY zREeT1-}U5vXlJl1qd&v-p^pGNqRM1#uV?PHs0$ap|ND%WD5(^@NvI5_GMeC&7BLAh zWWaffgXjZwijdNa#_2^W72dpn=qNI8F&(LFMq*M6Kw1@S0#1jFql`Ub7$?q@7#$Fc zFr{H&Kuk_xouLYNahud{Oph^PrXw*CNAb-qq;=@>n;Ng2Rsd$qUbgtkK;#yQo(s{3 z4c%7DT*+OeHpLepkHZsyVICD>R2YIc{VId0j54?MViI7;fYWjjb*N4eTraFN$||(f zMKbj~iXj7itgmP8RyKI@J}gE(j)x;Nt&rn9eOlg4hOn=+S}iYW*$TKz0)r&Zn3x_l zR547+5C!Ox_%)a~!s&=4ic?1{FiJ84Amg6`!0J)V1N>2Aq6{uTL(-sZ(L=&tP_eDMCe`}S@>eDJ{Gnghp<9L&ziVQ=fw(sHT=OZyL~965f# zkl`$U!E)0_57#VR{LsU{{_)POi9O)tHRr+wpTF&YQ6Sl&@cf^J8jl=3)Y^PyBpU6_oa4 zJ(EEt{RfuxXDRVAb`ocg2#LfRE>>DOux~d!(wdW3K!s**Ui+HdDD^a{fU@Cu{lS`j zP04-G5_yM$BE00}o-(cJ$l-%Wj~r@iIXNKWLf*|^lqAie8aXtIL z3TD74%R4b2%tn8LtBSFJpm9ZXDXo+D(essNMP+VQmwhIIc9^8Wm z7~I_*26q|U-C>Xr+}+*XHCS+Wx8M+hyGzbI@2T&t`l`;U`MH1WUOl~J_v*gxLV91# zIT0ilNayKBQaX`?v*&5fLKzR>AEXbVebdKAl1gXP^W>4#x>mlJv$N&P)o9?&TFCjS zOelX9nwgVHAP}%XMg2M8I~%}`~t<|BbM40ts53Sz7gpV3}{X%cH6GO)HRo8@j^3+EMVRrv?~ z8B7J~wgSXRK{oxt#$}+y`UsB^Oa;J!le%WvBxcv* zrv>lgwI~_hbBms%-$jPyuaEn>@B2jCMjIthticHxWnA}tDtQx2ScLW()PYradeB*o z);Os^K$@@}Q4go7MF-87;YV}w1n+hDluPyr!hblF9AOr6_d1 zsApE?Q|Gx$E`@%_OQ~29DHKT?D|bh(#ugC7k-VUfeNr8)>2s#C{*YMGWbQQZBT;3Z zy^g-_aXFOoG*MRep#i#1n2P1L!^iW(9WVpP($rS@okV|mG4|+i^gyVuxA?Dxk{_L$ zgd>kFu7$YDzQSOWvLT^=O%jJcDMH5M-GN;*X>xkp-pY1uPMx3O@qh^n({hJO3Y9KY zMRUs-MyTzY~{x0$E17q`! 
zKRJ(9!wa%ppGgy$yv-*myTjSimXp6&_@dIfszV{w!`?1N)$rz8RGH%osfz$Plby+S8Rx>AajMebuOxWZBkgyZ@ zFYMUG>I!5{#&E?u{^05@77%%3&(CAUNmw24e=W5oy&84cIU5VIZ6%9quXmi}VONr_ z3`$oz-`Nb4S2G=#yTCb!(9BxI5dCVSp_~ptj7!gIjFP|#o3$<9_0|KL_e!cKX)*6Z zHFe;|kY*|-bW=qhN7oPpVTOj3lNH?=bJ?S2$$!Pk&d47qhY>LDfl{b{k*sKxHEFb- zES+z@86PKJkh=!8rpw74H8%R-z07W1;9);z^1klPC*DXter$Ab0&2Vd$S+tL;+L1* ze)*2ihng$NNLMu}s59p^--%*0IVWW8QCWPXE10|PKdwR?O;}8n?uKxi4C1i3$D#ho znog(|Yl6`|OcQ1ko$B8QJ#7#KYtw${vP$KNvpA-1;TBmS<)QLQ73VsFo=E`cPXl|+?O*^*>4(Nf@&hDcab zA5tgCVn8aAcDIKz_I(-ANw7W+b!$u#!ZGSY8R$^GuPXK3PZVdVPv8X-+M{S$QwjSp%3N!@{w(B`TrWLMIc%c(pB&~3`^e$Udgj6CXUro zIbj)0c2XX_%XnGTxfNj!Cb#NwWBlZQ18>&u8KaWs6qgApeDAG&U< z$rFH_lPL>lDe({)I0lHegdtj_K`ufqVB#m!3A*pztn@!_8`viWV^bPPYLzFPo? z8*fA>5jCZu0Ylj|&?_t>$C>auaGp4E;0D-TC2okGtl(k--bZqG{v~0=1nz9H*~2E^Knw<|dcf-8?Mv0-V>yNtKx>1I)!;+Oy;6`6+vpYG zCLHCuVHn6GM*@$Hp_u&;2ri-yu%Z;rrU0bcdfE?)n;#v^!<179=9-)U$x3Ee)%XM0 zHHOY}*v`tth`ZlHX7H{}E>Ymr#AM&SosQCF38YZG_EGQd&^uXd$rm+u%G(zP{>0$D za!b+ev%Nw1-!j|H0yM}a@Av15%c;v7-@_?aL^^u&QtWy zl@!{evq_HYZX=nW0Wz{El^COgsCc|MwmZ&>q9xlZP*>m~#|||$Z!6Pqj6N&P7LvGl z%Uypd3X=F0Jr0J0MmkdJb}|@FTMRo6Zeail^n%7_ZirpcOb2%jxMb26kZS38N+v8% zRyz#PAXZ5B9^dU}p%#XEOqn0&Jf@Es#458rnM1yyC%rfM%;x~Q#O3)Yi0!&6JA>Qe zOnd@~l$t#>!tqrB>AeUf=+Z?{p!(mDR^i{~%?RMTiq8=LGm98_oeg`<6^9-f|AFL}$1(gomEwIpyiUTZs1hkSLD zA0344G7*B2DaUtcuGZbgNgvvaA$d=opT@I*MAmD;31>CrsL<8_*QaAME>e_tA6cb# zDa1{#vlOgfuKHU?okkVP*q6s5jWN1RW53}o(m3vb!NMcmn4ggDpD4O&4TC@VwW^NH zC3$G27X1=gI@Fvkl{K#hDb8ZlRXE7Z{sxIg2#E|ViPRogBfjJVcV5@3OL$#agUBSJ z`^oo=v3_&reyb_583{IZ?v zUod%+dAef@VG!$efY7?vL|An!FQh%M2h>?g%}Xtu#J0RPzVppV3E|n=`tCa0bM1m_ z-}M9~kf^BT8(mz%767guX%A-?X@>Js$S_mwuO23jIqO#PBi|w;-(2wT7I40seS73w zUTkY+;ns&tJPDY{ZQCnCB{R!x_`|k0B$iKNBLf6O-&?+v*VP#|SpX+N*J&RIUwsB> z&UKpNZwNLDDMF zu&dRjP&H365nA|uN^P$ythAQCT$>#k^E8P5=dmx=p$#1A=k||0;z_m z!{$&GnwnMdUKI?*j~L>9bSX9%3ITsP6PvT?-&F~)e!X~N=>)j;?Bph@D5JsU3CQUjAif@Z-qU)H!589@4fTkOyyx^CZcfAxl^XBM0>^xs*lZm4r+sn%d0Bg!PS%v|~S`@BUAs+mdpz zkE%kK8x?Vc35&XpfsgX3lv?hz;w15gruK`?7e7ur zD^rGv^6dB9ml=A}*)mHeTZspr`;29Qi1{xqV)eaN*g*y3CMK%p?w2Zyl&@W&AZu{( zae1mX(x;>m?yuY3EyluGtCMs#+k;XvFqbMp7BB8oMb-KT1XCGb5Tl{+F!^mijm}JD zv3R>cJ&()|;tBj)kYeL`hR27p!rL>#W)Op?!5zrkE8yR`A>pr?bCT__p`AziR7L~h z!;F<+#jj+bq0vaEYFaKAMShKG9A^k!Fvh{@;p{3KXqi|GQ4AUHR@5*18Y{K`?DYHn z2wc#lvV!wt3^H_%CmwvbD2#<*{ez3mP91)igV6+PYHPbCqxM@;2u!p5IbE+R6l5a55;o_T8CZ07 zT*No0f%i+As;VQ>@h`daITS!qJ{G=VHS;D6Eh+ z+wZf_-EC_)ILpSJP+I*iNnmbr0_++yHBs=>9;T0+C5n5PD@66e!^6#5snD|KRV6KP zhXw)VT&Pj*B-7bHrE9W97w#IMP1Q@%L1EdMNO0-FU7DEK@L4;6S&P{2DWrKLA%{(9 zU;x#u*_Xg#bmy+CvCgGMgiihS3Iyq!o9!Cnz=}wwa_oE+XPWG8r$T;T0QM z=*`Kx6{)?dfpXuD`s>j^-%m}0s%!AdM~B}r{UxY}b&Q~m*41^@v@yiE%n^2w`EAfH z)E>-nzss{=A4Awnn@j7PEEM@o&haMQ1sile5m9YlYwbHpNb&Z(v^ER_n4uF5Ee7&cF&Pa|C7jk#wlVbCiB%9|NF(^Lk ziPrPlr=Y>TeJ2LBgdl4ztrn^Gl$(b4S|Y>KyL_*F!n&o5ZIPzucy6&u)-w1Tx+AX$ z#m!=*3}!;n&BQJEPT$|nwGAQ_v4~7$MtMuM6n1Wz{Y^^ni{s)2+{%0=EkcqOaG~#8 zTr-fV96wbJqDbn|oOkn2oY8%}CVNfLm6?%X;w054?lR?@0XrRUWxlV63S6Ff8HW4Z zkLWI`y+#7>|buq*gUHXWKb_& zm%xhdoV6n?Wl?dFQ&AEbOg&?!>vDz)F%i*;`_e>17$r&svdO2ke#w-OdVq2nEt!c5 zKC#|e^IMxPCjFAC>&B9jW}cdXW^h8PA@sf?97zX?mjc!y3YJaPg9;~2*0^SsS3@$A zCOLI4p%V)0920FE()xk<1O3u6PT|$b5NAZimrz=NnJys6Tp`XF|JrM;Tdt&V=_I^L zr5?#xF{A9;GZEWSvwfo%X0qr2YS;QPtVCH1Pf6!WzZ)unQo;IoTd_wMT(3j@xGQOJ$M zszH5A4e&dRMb=qROjs%x$!XStt(@p{{cn@UjKrzNIPMj(Efc!~gG>$e6}RMR?$*UL zOFWjKS8bJlDJ`!L#M^*t8C4g*O-0h68776vO%HxssoYeQl%0Ehuc;P|nPeo;8^K}T z24a0pYb{BI)~_v-O6M@@*@M72r|w~n1OPwH>ghzYWbCEx8XJqsM!nJSA4#e9*m)mW z1#`JLn$^Swm&#bPNyf@@@j7I!W1B+6{ZmH8QX z>(WcX*YotCei|}0nBTFs0$VKo-SWrNd(r(>c*tM>u^c{m8Th1J=fKBb@F4S?C1_Sr z)(*Q9{&BljY4A>351Kvk0 
zoM>l@KNhHcRoG8Qmipvmzqw-&lNjD!->$ypRk#X!mRKx#aqq>K?2uhH5alvX_P5~J zU}Qm^ejXc2TVS5B1oAm(Ml4-xZ^^M$&)EOeU5}zhdV*Lcp&L+NfMN=PsW;OhY{ODk zrMaPoNO$gVYLzS@pUxsxAfvZ;l7w6?;)Vy0`1&rAnCcsr{+U|uR;av?k;t8Np$|1^@655PwC4HQd~CIyy_JOg zq9-fsahe*Y#VKk;BFn3g?s+!Zs|zXaQKc@aom z5?t>|5|F+H>}HhytE8M=UY3%YKhNbMUhmhU2T0j znut~J55b)nfxaHMmlh?7%70d@8>%S-)G8X~6LH&b2EQ!V`F;QHet*2GNpyK!LRpfe zIm`MLbwH^Ex6B*=PnC2=-YOzN|x^+wmUJaXv0F61&>jCyKCz*Hz4LQ&mj+<5_PfLak% zEt^n{Ig$G+ftamAmpbk+(g54%AV$4g7IyK8PFbt5D*%3<&G}nFqE7|eDTz7(kqHsn z*aqQ3&dTzJvpm6nrhQ^*^MtIl2}TOtL~(02@Lx-~SwXx;w0;Tr0mc`@`-j*iY?el+ zB^qLmwia^7MeCKwIL&kn+5$+i^PlSV5feQpG&x7~^d?+h$EC$}JVZ4d3Mo`E6O@fx zIxBK-q$GIY|FAK?CgUAXAQhQG9+-V%wc?-44LAFgZOZ75TFsfkaSlmRH`KJVI0shTxz zUmWCOvW}P~`IihzuUGrE7q4W+B*7}!I%RevGeh;Iy!Av3&K?OFH!iT~4vGOU@W8;Z zRw|1DnjU0gJEtpit6Z%RllXJ|sX5ht(by7kN+obx-rTsFg&~P$wUHB@oxQm$D>xOU zC-6(hTgd_F&mom*RQ#rO*zKt0-s87w==1(;KU8!YTq~M^$owdm+rG2+Ad6gG zvx-=+@(Tqk)c)b{{yfb?!&=hnh(%u@oau11Gjz^kaa!!4&fe#aeVp5NW#eU2a|Ga@ ze{o7OD46YCwlbxTaN3YLU{|y%;h0zk7%&PcKv~C|gN?T-40MUa!0FjjLy}9!%y7ug ziDSf2RbQ>$S`x`(<2tV<8yb{1ha*5R&2U>P2(C;(u-Z5JD?^;_^<@P*uB?+T2x?Gb zI&Fp6yu)(gAbR1@?b?|K&J-@MkrFdvwG>v5WkTzXtSJ8a^Jfi;G*8C^qFE(%L^^sqtxKEyrPWfr`;{ZA zQUiGkRVb5Hm%@;56hk}MaOJTdflj`kf%gOBZ4i-gQO5@GE{hLHA#{WTj zW~!!)0f*!(*rf1t^zb9a^OCa;s)|!CE{c|GPLH+N9We+eG}G2l0!ySa0%@F1O#LlV ziH_D-&aA`PL#S0mM#Z=jto6x$e6_hL5F3j*1;>my+I(Ik46{;-G(RS&fYRA$Pf{O2 zAorpa1qK6)*AV2Ie4GnRVpmLC7#i8dj>hpFkCmY^$-=bjV=%l9iWVFGMx`F1ztx=p z=(UNN$I`&g`dic=TNKuZBwo|ZT=0s3IQg4;>=BYGVb-^G}t8SniHEWi!*EO`cx*hc8 z51JTQ(I5$2<{vE0bkyw>K!1%6G2a!?Mmk7ek0Zx9#(UNM=Tgi!e-;;{!mG--uK~80 z+so(SRnE^qlwVNkq;QTl3clK{<@)sy6A?-;)Sc|R2`7qfeEJHT6|f0Pw4efS%A(a) zvE%IJuohH&bqpk~<*7ho3NSV4+6T2!$tWjU6%v&G8pmw%NmW{6bpaa2E(5VA3aCPHkr&U&>6tIQ)JPR%4IlbBptr*0a zrIL+FIg|TV2H5*EzWkj;#OKb)gAeFBc+K+HTc~rTcY{)U+rBA(gJz$QNKgU8k`@O=%|=E@UUbNe<@=||aqMZOFg0vQ0O`hXtZ%C3arQbN)`#!*)(elLZ6^n62WlK3t$+z? z7$I7IZ&3P4PpuBp`Y{uZCRcF1ceGw=rh5a8l81pcMuYq;?qz|tr%T$TDci`nCFX^r zwB0Zm$-?H@{iR-Zj7g@_u^gNnENVUw+H^@tNog8W!h*cCZeIxHl~C@mFD5GcLgsam zjTT%yn;7(tqFBpYC29Ja*|_tI)y*IZ9m;KPOMS!|{dIabi(Y~fch4acMnuTl0w0PY z5#5Z~AZZV-7&IgH(-a;-`{=MF({Jir^gEu_qa*h{AkB$7Rhzd z()m>|;YgG7r{OyO8t7YgY==C49Nr8o86@BtU0?c$kuyjudB-!O=MNc%18ctO$#coH zO$|zU@E>hbJ*9dRGj7ZTOUUYSks(f@N;=?WV`pbOZ#lwufVJ5uEB7$T$TV6p3D{b9 zvh#c2+P5p*yiRP&KORE6AG@bujk}FL%4N_hIygpBHonR851dDRMfr&HhP~sGCl>&C%ZuOuMC>r9w>xS*xQXcAr zujxFRwL{QoObQ*d&sdAp0tCScwX%36rSjoY-ooB`_;qn?NzU(0%Km{%Q>TpcSy z6J6M>4{7++K#Fdf2XS5=Q@TR@Ah5d#+v6K_r+Ov#RQFqI%?=QNmET-CpS%y%3+UO? 
zFl`)44xyTy%~%$}$E}I3FG>D@pY9xUTyw!cs?Ii!>p7vA7*gDo*{V8TSDaE;;IF8^Zhoz>?6=YzGaUZ&Xwuvs-BBW(SJWZLK)2 zUiB7Ao^m^Mx%36v3f6b_n9VyEOD+8%?1>fS~0QhyD z1SLHCWHQXM#E@z*0wb)Nkb#IYcG4-4rT+7SL zjZOz|&$o2(90qPF*^r&-i{~BlUxy?iwhrKy5G_czsCrgUhFOnCcSYiNpWSl{8V(uj z-vUfsucb~ge1fIS+O7MjVN^Y16`+d$5(`7IIg^ni6-U8HS{lo&*&ZAXG=gu4PjO9- z*_DUdSa%?}Si}5F*OOZcYsI9o!b-^A^&!a*sCK-%VCP_3c=yU8dOKDGDiqBd76y8M zfCT|}+IXO3L&erf#mNFD%>F*tG0#sRs4Oz=zYAa7`H4`r$i67Ndj3k>2pB4OL3Kss zICEa)(VQ5J6B-~j{FwNc54JLf1LRT-ywSGZH+ZgbMvvbtc%ilfs`sH=INN~z=R0P2)&U5nkf{n zL%$Y<*U^tA`O$&endT4c7!g}rTR%U)#r}jB9DsS*e1{tB$QUkHti>o zcw;kl+YYcJ)H-q-XXV(<#^Ud;Ki!Z1ucwCGT%3m%mTj%g`rvEowY3C8F-F&wx{5FL z+?bhW2i?IO4P@)g{?8L#zkg3Kb=pmzt;yN$g@|Ut{95E93`)@B1t*#Af<(SMNQp=~@kA;p840Es`bB*DIoCaCRPOd`E(*x z@aAg3PhL(4JcOf01H`<^{DiOt`jsu3&f?#7IPLTI9#J*tgcs*F{SziUz#E^Y=Asj) zDLPBuyicE0-t>OaCdL~d|7)AM)3a5Y3$-kgM%Y%wS-S3K0TRag-4&v-NL%+X+R3;;I>Dx!st zTp0zSBo*y+<=(4Yl{l4?_#E)eO+%lxuT$+ZeOPIFlNs?WlEyKHfBZne#j;MUiq7xL zl1m|fgqe3Cw1CvlBom98qgGVGx|ao?Aotn0bDEg{;ao1F6Zqu|D~?a=~MRgx=-D}e!Xp{ z*QovGUgCt!LD%X-X3*EQ4@(3t;b7Hh)$&lxH8buSk&W{5AMQ%anwN1;Cx`dv$fti0 z$22suvBhealvIM#vr)%BZvaUA(PT+ero|&H%hNh?MA!(L-MIRYzkM}teXBj8lM3u!Tv^rf21gMLwYu*bE;~q zi*>7)`P9L&Uxh!`Eb|z38V^<*Oo=NiD*94S#$AtWDq-5hKfkEre4c6K9fuDyswZid zD~DnliO@$5jm#t+S7j-6AsXG&4OSQT?1-dTQR?`Ll1{>|iXy_-8ymk^$3qvaUzJyg z$jVEUh}Eja_ePv3K%WzaIiqOqWxU$sdfGavR;H?^*6Dt_jIU)Q9(N?HKvxq>?k6%> zrf~XAvTC7Eiu{kdOdUPaFqNqAYoh?O#Uc0SzWi)@Nt=BaCpkS6BY+rAinvr|;XL!T z9e-sO`yRtT_|N(Da#`OK_fOSg?Vu`bQg&B*RFDRC zbST;9j$DSBD4l1ZSW~22!un6*?D-#}H7-)YYkL%Zu9=`M954wW9nc*NSLemgi6|1^!;-__D?<1+`pX$0-!(fG#I1ybW+YBWDyFXrsttO2NivvBim)* zGV9QhluEHTSH>rV8Jm;~x(iFTg$p~m6AsDeQ(6C(Wy+Z@YI0DF->bMVP>>I=BuBoh zMPAzC+$N79swDKsaw|27`g(t~VNhO0{z--wVcq>MSo ztP=|;;4NZB(>MaCDtMzEWH$*K6<^AqPZ5(&7O$ju8z0C9`NOBftNx59XaNUmHqfKA zAj!$wFW=?M75f8%VCxz00jT;3sM8b!0Ij}$qL)0F3Dbekq?_pGZWaV?Yr6{)Q&a4+8aT99Jsw$oHk%3!H> zrvB%aJ*BA=Rb^$BT;;@14lT1FaS%M^+ErY0>D3*ObhydbL_sNFw$YBJe(d(%;77b7 z#)uYwOP~lFJBp&Ze~P4d$gVzsN4@MjSH*T!q!L#hLEj9@Dq2K${A%uUiG13IV~=Tn z@MnmM?&(U+#Mrdv5^u;mC+CrK0XMQmkF}Ta&i{CItW<#WXhuuJurX=+du?$bW;0CH z3D+TOgz9K!Rh2D0fcZgo4P7yXqDtF7=ID|{(=Dfounndni-+BaE`R(wc1i}HC-AR% zKw0d~UY^TKd4EX2&r#*FTa)z?&7E0wjRK_KWL@dn2&PIoMLJ@TraAGj*gTd3X(bPNa;j7(&27WdH zRMDoli&(o@ET!_IN3z@7k~HCRG#IkYg9H8G@H{~~uciAN*FvR%DyA}iGxp&ga?8kh z|4|@{cd-lBK6ON+SuI}TTd|wMRIMHDhL`Ec)wpM=nI($R>~G=_jrGC{Ov^v;Vq;& zZvDI-c)s;{Joh=RY5Py@acsT>BBxF}%KUe=V*ieS-QnLK?rXoFjg>YV%Qct3ZWYz- z&LG-hmWHk|)#0spx_OQed?6_ezv*)I_5Sb}?FP;feXHHt+sF0?Jb`z6`xi#NX}{a4 z?oQ8dWu^XD8XX>RnEh;uxxtH&ZzU1|FQ^&UziUlL;l=}{3ORfb4*x%S*vkIGLe$e` zX6>uBEx%VJWaM30s}+a0)rJ+2&v|--sovs;4@gsUd7fI{<8qVV;SC~7HthS$ZLN~C zYT@Bhxl~tay;5he!ANKQ;J$*|v4@n49sL#<$gEkd-{K-^ zv;9kZ15&%-iY<`d-_c?#{QR^}d3+4_{QPtOq^hpF)%*4jG(-(;0myxLc=Y}0u}4~} zwFCUXNrHu?Rjc1_alY=x6bQ)33rUURYM8`nIU5L%;B4&QyMp2P7<8?$`>m+fxX zm61IEM(YKdzY)>VXUXY2;4|&tN6uIIqF)$wuiAek%b(p&DdzA+ytf+qUJo_wkFX{{ zif5_aFNGj_09CcW2>G4*x=fy3s=wjZU!Ccc`#sNm2Q`5?JK|SB-vx6g8X!R9{cnm! 
zpcQ>T&eJ~Q^=9Y&8yvvH`BsuMUz2H>Vh{p)2H)ZMXRL3-D}rJOM%j~JzrS7giUyS` zW@Kh0G*}%c=J0#u@+K`cn2tbjVcTypS!!x(#kzbN3g{~}aiYz>T;ufhYjK$MDF})Y zkCZ`IiOWiBVlENh2gOrJYA^ z%AndshqZafjd*K6!h)+zIl;(TY1!u=i9%K#!iI7JSb90c)Ip;uSmMk*5me}qjR1Mf zpPu}B-a=bBsKkX2#iP<~VMZr7Sv*gGO)_Nw0|yvzoJ=ZhFvhG_jv?8eePH~JWxfFeEM8i^L@<nP+Vjx{2QtO^^^gta?4OJtO%2?hq_UCx6 zpxxw>&emfArc!xd3Bb!9z1r$<4#jiEI>C{gx^$uu=uM$Nw0e5_1z&%CoFQkP3xKJ# zQ+A17LL-}WblMIB+Wy@A2C5n?BP9aKJg9w+NaGp^BxFcgFs@fC$?~73C?hiiqT``<7C>tLi7Y{@0d0pu#914YB zu&w(Aet_uLsYOo(GonjDsODG(qu-y?zRSwWLLaTWpDLQ0W%KwGPxX~^*(+z>73sCP zjj6_?5b(Ka+Ly^5lP73!LTdyVs}T{}d+o~dBcPKI3WD#Nj~C`*8HBZdOfN#TMWN9J z=;-J~DbY6wAInO1(bDsLZ^i^4x1wka+%{U9VY;Rpahc&PUS3Y?dN@5DKGvFkvJk)i zjbEamFxk!(B;s=Lg}?`Axq}gWWL6y4o2~j_++?^NeuryXvz8iyVkhubJoih1uMZb~ zFVe!1k+Don^r0?pD0(dpg7vK=WA1Y%uNw&8#lN9-Pjntll6${l=@CAboJ%_xjju*& zYqpxc#`^g@8t@*OxpAQF&+U&m_eeIM&pHi~2>%$l%COAkb@}V$>?|v( z0Vf2#Q&%2jZB4{wx7uKPq^2aaR$(fQjZ}T&a~XrriYW)7qlYjy@Ff4F5}+ugzPp4p z>Aornu-3LpM?X*h+X)i>bp^AG0Eh;rlS%ZLPPE!XgoynP5~Q6Rb}nH)P6Lj@LpJN3 z2fpTg_aAt<=Mi^WSHAdYe&;cRa45Y_UbipjTQ+hxJGTcE09ExnB}+?s5CxVNTRh#Q zq}r~V$qe=hn4F6olJ&MD531a^*i`C{T}jKR@$rfL=c@v7{#N_fvjruKncVkTarFD+ z1q7^M6oGUi*K8Jtw;#>#w#(JL&L?KpvY$1vz1*4~FV=3Ku#}?lSp!pZ%m;REgJ)2M zI6hEm66hqXLeQwiUcCIvhh>(HB$cHp}6z1 zE4~~*Te9j~`+_Cr)mtp3+Rh!4?5cTF4aGqDmU5B}tsqD(|DL-DwPf22ZRV5UE_G<; z(ke1(C22_7%|T{7V^Q*7cV8U)Rq?tGV=R9l=CjP0!xzoG|HA?ZqCG+rHXYqwmABVS z2K%mq7Ov$U>E(`UbfK9OCw3fZsX@1>?ntO5BUMZ_Jp}f&?=-GnxymzbSCwQ{4S2-V zG$YiKoN+rd;b;|jnThs#*WF&l0_`+r6SQ++5<}^^crlQ>IE6I~cTGS$%*`EFQQF%6 zQQJ}+m2ZbVMkPGCGa`J?eUO);;`H$SseDwFPg$u6w~oXFqTN~+hxZL6+lPKO0IhOT@ z&nACyZwr9${?_L3kHnBn^eg5Gb%UpD&O*gUT=UP{$EX_J%>WXY5QH;j7&^0lGc6yD z#p>0rI6UgfT9wx8lb4al#p?Bh;k(Q0!?{&_3^A|EvW8Z{afIU9o@A8A#k$7Lygbhu zDfmjX)|Rk&XQ-qE4EwVzX=pBB}6lcM$_3eQAoqW%}LdJz<&W57uDhMFfmb zT+bdI<}o`ibZDrmJ|gto8g2Y7x6k7K-e9v7k@T&e&T+3;b~O?azH3%^v(?-d1R~`k zRdv@li1yfBU#c=SF=6R^3WhB~LdSq$#c3^Ks+b#w*qE672g7$iy*x@6F$xKhMn$xe zd9eNIzC~-b*-Kc?;dd9)!b@8%a(!Gk1Z}q36Fp>uJl|Ueq@&Mfe;KU*_>>Mq+8#xc zwot0jeNw4I8}^4OGhL6%{&%~rrI`Nj815}ZYO8+*S?!iHqymL_(qH9KskL|8w*BzJ zIC*)WKVC0K?2pGM#+Pe#I_;j1{tlaDQ3?3j@(0-9&OUz_{_C*2-W%nzA71&@;VI9T zlhR_e3%O5ht=9i#GI+I1Ml1-PY`f@0f5ly?)!Aesher6aFj&-VYHGs9&OY~tiHWX{ zMu$S!{RF!uy1^3y-TL?|EcPv%$D`BpGWYReiR{^!4?Na%T#fDRas~@5Q0A?_^irx&Y{0 z1I0=9o=ip{sHd9B^tHGzm0Xp7sUzq+c!Mfkk|;O7^#naxwClRWhS#;+;_mu zTD@#vF2Px>U!^ubpE)ql|C{cDmi;xDl~hWG4l2nd1!tjkLu&C=7cus2>T$}a!TA-8pOFhG%ZeWToL30Bb zi=Fp|MoN&#bqp=6D{Lic{ME4T;6PM76l$T}74K{)>)&Hfo`Ac}bL3$oiLB1wad$=F zd)xQ-wl9>@JNztQDbtbI_q7>+k3R|z%k~@efGR%Eb3W@mDt%B+hBx@Wa_M*MU}#U# zNUZ%5;wHDaQ*{t>a0C|3+YI}^AOuzBx3eZXEyd#%W~bu@ zUgz4(w4|V@kz3@O5n%dN9R!o^_xdozIUR2DsK3EwzvXi&*!<5)mgl@HsOLd3JY~OB z@%>oo`+BzzCYOr;ekjqO?$?XD@9(w@c-I2eWWse7AT7@}j{%5r7_K*{Ok;~t= zi;Yf%726_dZs$D-SAgHdXey<2oY~Q;w_ikDJ8*s zbulrRJ1tH1;s?c{51>9;92%W^bQX*MXgvAl_p6T!!A^dUGq;5p?mag9>@Zf9IRpqJ zuBZ2JuklpF+Qa4dyzi~`PJix>g`alOoyk3_G;09#(^Hwew;_0j5)%-8$sx$`udh~H z*XE6Z5kRo+_+%O<$OpM~;Ic1*A{Fvob)69mi_&ZO#=#Gm=whWFzb3GHUn-)l*b^I% zr^2MOFLoktq#IWxti+Buh+iC2q#k2Urb`@GFt%qKq)U{B8MR0!2pq{dg(91)9M>J1 zAoy)84MxF@p+g^} zk%nf3Qn40|?KY5Jm1if5+JXFS84n?euyO^~!OKIfB_f+}mvk9ZZ0bMt?p&e8^+PiEw}x``HF|o-)ora z#c7+m+OU7WV!-XbzPMbY3&&%cwfo}Gj5pT>pl70|Wsn(MZvUZWR7G~7{?mD3IIzjG z;S0LR-M)vzb}Ry25lKW&`XBvKER5|g-ch}89wH_aqcUe}gRx@zAH=jP)jHspIVItD zY%S>h(nxgSHyJrOB(zOGkm7(;4-zN=_09E ztL#JxB4YQbn`ltLffar) z0uP&)Ti}iJ2G%*5zLfQv_2#}QV&{#K40a2#=a8lF$dk-uY$Bqls`jX;>2a68&5GGP zyTe&Mq9bqxDwNW(r2J^-GWsjod3kXp35>eU0Z@1RjWgL1;o;%2k^RmQk&*o*VjDx> z@+lrHpg;XNZ`aSx*dpZ#=R^^EwS@tCXK{^>kCnc-%rLbX@vO|0ft z?i^s_ZQ@j<`a2i%e6(q3=-Dy~T+Y@qfNW!X-s?@vfjSh?XGY>-YP)PPF 
zJX&R;GEAJ?I6;CTnAnCY;*%74UiKsweJ<7*qAZbWKv|?6tr+2dkygJLpPT#3{MO?4 z2Tb=UsUVnRD~3{K(UE_?gZNB64{a&lxK!o!QLz_Np^3ZuPBx`$?Bi?mNf_C5z*1$$ z+D1PrRA_-<`1tI{u3@wiRRwi@OJTKQYs%UQhwG_YznK~-$Wj_dI(#({X0lM|s1|d8 z=q1C{@9t-rBZSThWqlvUZR?q;fm5=m7L*NFUW|g!loxNJrlu^eqZLBjVwP$hmRWvM zSui$3np-gw8T>di6LMnuv>1xX5RIDW{-dnInUU~C!4ki$2_f`5t58C5YVUGnBj1{) zMb*M&jnd~Nptn~Ji3Tyb2{IB()Jrx$b*m$Rn!1k6*CT#uDVBIpK zrR{M$AoTo+)?sES5Rk7K@bQ3ppOsGW+5PVa8r_;UJ`vkV{r~$;KGkY+#Wa z8GL~J=u6Dub(KC>RFY(0tpEZsam^abZyv}dsj4z@bin|YyLkBu_J?ETWn~5Z_-?O$ zd>R$=#7t8uq9!!zEfYFpwOeK=^o|npdo^-2P*t5TiW80{k+Lr;x-Q2Rpe8oFT!|yrl&J@ zffx8JthwB(THd%?f|sUl@IybAfBl-W%L^or*HU;NQI)c9;)FHUi)PKAha61 zT5G`XH7h*Sc)8ksdwF|gI@N&u|IqeUQE@dn1 zPlo0j(UpH`hb* zP^F4ZdNV_b_}#IU(TKzILy~fSF<@7){sLqEmZpa#nBJb#qxurg`gcKn+Bs#|bEUAv zFTXiuQHql2X6!&~ZkzGs&bZMa7d426^{623?z#Nz*8?_%6cOFbVVE?gWoRlD{)+ft z&xTKdlqQ+#IU$^KCx&V4A1JO(7}9d{mlsTs7-BWP&D(ia1)-OkTw&UjFccAFJk$v( zlSmTODy|@tg*zrM)FHjnI9k?XOOne9H7{~xAJmilE2A^8s5W)ML<;S6Y)j$fmEkA; z*SbD_U70yN18-Q`(@$8ZzdjdP8wlx3GU&k%J0Q7 zShLa)wkGf?WWhbSf605t4*oit$A&&P{1|zM+ZP3wQOt6k@+QaYLIf1!IIA`>Iq4r= zv(yid(n>642eYER+$>|R3IADoZ)f>xTT95Z)GBobo$zAj9L#|pt5G5WB9&9v0>Vvd zkaO8>3Ao?$??ZoVTtc3;rQivFCJ*Yucdf>+Xs70;hE~TrIMmaP(yilCVxBiyZ{_>k zlWfe#mgJ633TY@gaK0W#vUI`WHy}?T>JUV^Qa)*h6Jyq6+|9>=HtQ_0NTgN#&PW;P zBO)WNh|JUbxqq01x8S0kbRyeaqA(I~9x8?e6w{kwGl^_^I{P!LRNg#n^UV;L*B|#t%K_npM4KPc8z|G+zER*e8r$ z@L+u=b_cIN*SlU=c!Vez*R8<&x%wxM-DVQ7BiAha@}a5EY5i7CC+=Y(KBpO%8cLjT zm4~reEVG?|db3EzRUod=mXMFTzqG+5%-mfj2}Aw>9eyWV<_U6n2(_?HY!YGuu;g1r zB@0p$BC6vVQkcdCU63`bN{fm(PopWwdd&rHIv+7E4++IKq8NSWYKsNnm3m(cRH5A> zV3y4-)J8IXe<*~j#=!JoRz)EJHL=U6-6z&42cT=(DpEpKO=7ePTS+C;mwvERLo&_| z?$6BVf~u%VKmNzcM5r!!A~Axtr&B&QPR;tkP!cs!hdYonNP}7R;Nmf0@(9_hoC-lK znRR}e#O>RmAvjt6#A?K#))^r!ODJBR{aVk5YtT}E$3hx9-y0Z5#z7|e{8FxRe z$Qi{^KgVA%3M-CjAlHMGd7VPw;!|H{0+HrlC7B9UfcK{L`cOM)ulne546o>de>rcU z1*9BZP=^oOny2LEEY>KfMy|4)_6dt9wW}VOVqv-%N=9MZd)Um`dsuO)J?eMCtT;BC zFaOAa*?fD5OmS}5C_?}uaIf-01#N=z{b`2wRZ&zS7N#i`7c{!Xt)VLjVUbP_#4HA_ zegt0kt^72*Ru@8E_vF~E9!48mn#ZRG6ePw~2n0#3if*TOCzlhMp&hiqZN@LZIJhfO zx5K&5v=MK&l0i%BZqQsRfXU_$7EH&>e&m*wy$2VwLWpWpidHlF$Tk;^peu@ab2N8Uj4b~lWZDkzStJ9W}l zi!K!jtNsL1kc0II5ZT}W%+<_iZy4zMs<~CXN@FgBFv;$|U%Ol50gm8%X)t6CGGp33 zk8*;Nh#UKA)=THB-4s0cc!x2Gh6kr=t&C6MG&#Ml-=t(*#mx>0?K;^m|9F^3PVLrj zv@VhStQ$1;oh?~!Hpj@>vMrH}^}yT(;=YApvHTAmDp81?^I}kZ`iV;qV>fWaEu^BgkABjRp z#1f}f7+fj3G$D>hW#O~0<6Z)>fG+t=xIz#FU4JQb zfu-`IxYA*C?*cOTZEYRM!- zrE;Ca_GaCjGR)ct`hy>=)d&!TU^ej^mG<;;2D6Yj%uSg^wAjH3m}<;&!I-u$O8#ve{=JA&EAd~ks*U|&X{OE_QzHvsgz|UkQkO%q=`qEM5nEK5or)7!##5AwLm2kjKMy&dV$b4| z&VDF7r0TO7Pj8sFVx=jl>eC9)h%!{~kIQ;<2@zX^d#k8{y717@Qy*(AtTlgHLkaCc z9iXSC%szCJ&q0hwgl832rJ&6If=H5j&8*#y-Cw7BUDI(Pny7?3MAhN`3x!{2MWfa6 zK>BeeK`YCpnHlrTI@Q(J%y7HuY%gTM$fXk=0|LZf$(1teQY+-hVr2eG)V}&~`SWOU zkJD2_1oR0bAw*OyP_R`ptELytCUIig91qnqzd1t4r|txa)&1Wc>Dv`zQ`70LKZq^3 zY?V0J{p#LnS~#+lWAD%!2(+3>Dp#p1MyI+GdJn`9u4|!ytvKaxOi)?v-F{rfqLKSG zU~5l&ZFqNuPTX&L35x*tqdc4&zq}af2AeXl!9RbfkDJz+()3Q&^A;OmS;1cA%`B7+1>VJ;sG^ENR)4frTb{w|Jz4ku*hM4%ex}!y{u{3pr-p z2Ds`q;G*}@3&$5o-YaLk^R3VRr@+z`1m*x*%7`Suped++ zu$ghHzN(|;Q9(ER>b-obcsDZJ`>y9nP@2KA6t^@#CYdNYQbewx=qJ_Tw8p(20@HKC zVSUMK2t%;6*GPO)Y4Hk*<=#frWEuLXnM)s|!y(qZj^i;a+b8>B0Gb1TS_Nu`(imD&T3g!H+da1DSRwJ4p(LZznh;b<_O|e& zhzX3?0EZ8d2+Yy|F0@SbZBi5XD;)XU!F*b${S2~-gPIY2Iys#GUlrJ|SE?dV5`SK< z|L)f90Dj4?<0VLxPPi7ij)i>cJ^!9eOcal&Or~>4KnI}Dq7UTo3W5okdWK|h_mubq z-2M=7TFl%5Grm(%{K@10UgOR?oef0VUr}Hb{`p3T;pFJ_>7d;v zFX$T?cf~v*?Plm%UN*4Jz}l`-Q4$HbzBIz4)P2VBnFF?|C_Fx8wm8Qp%~YS!T+CRW zBiuMdT& zH5(nS_vedF>#YtOUU$bSNkQfhdC7vV6W^QMFo;%Shf@(rHk?lMsXRP9+;_P>5aF`ac$Mkz_;q 
z#}HUEYDl-uZ(!bctrz<9VrL+9k7;(0v+5|pEi0thwQe@E0#FI96jMa!vO+2 zy8aXs>L-1~cgaTSLQevWXQty#8^TR-~mSnCPD{yeNy|k|ZztmyW90uhEph zs%TNjPH*9VOSerAh6L}Jc>`U_eq2c#jZV)cwd0Ig-xW`BmUaL7cN|!1ximkPuK#+Z z4VF%2FsyQ%aOSePF)=F@$qZY-$hH$@r@#&nJSPA5B+X($`5zsD=Vp;2DLQ&h6Q7jJb=O(n>|1;~I^Vr$;%Kzx;0~0!NZhU&X0g z+>$e$k)#Rc83WJn$NB6X#R_!$4}ho=^3vn8Oz`$YEiHdcAXR(q#K>i0?FGrv~=C~_^TSkJ&))>ff zUHenTZVPto<@2~3S;CKOdqWm~Y=&0nT2DVr@cP0TmJUi3d z6BD69ad3+qCUQDtNou#tz>M63SI7nmQ%#9Wb%LFsseQ}^gNlPpF4m+@e5WWmNs=Lx zOQ^HZV95xBnC($ts)=}OC$m34D$Y2QZ6 zA3+B?tKm*9dRq7Iiw%MQ#}J{7|7;H$EH-2F7CA)nsh{pjfZSGUP|v@V`vKNI#U26He3Q7D#xt3oD~ z*_~2eT^%>aGs=8*hC;EG2wBr(Bo^f-z}jc#rgm8F{mKW2?6!nkIFwM*7;DdDG+STW zr@cQ-+BL}M=5Sf7XO@)#p)j3qt7kH)OhdtA6yi~7+&%;vWMG^COe=IF@BuKWo!N!s zYqtcZ!(O+4dF(4CO!7^`5k1Du(%Tl?c{;B`+)va+1ZT|pAFfe^U++fp!Ax2$gA_e9 zvKb~)wJM`zO`9{%1L$&s5$*atZx%4rR4$d^6HuYCalIz?X49+XO0ef*m1Y`s3t%5< z=lgu-b-9C(;a)p6Jw9x@V0DllcGiG+-yIbj`E3uw$7J0c>kZ%>TP!!I8P4t+OvhV( zY`|ckGsiHZGnn^>dbK`cfFIQ=!u-Q#zM}}|jI`r1r-FLZd`dDwS z9lI6&0TQY;@>8$_UGhE?x}fX(Rx}PxAR&K6o(`T{RCE|CBXvx^z>_{yOlD=0>{M62 zt|Vhqtl=7@5@mM}|DH^Ke^W_6QTTEQP`z>CLBqP)1rYbNtcpESyem6G_xVgAFW38v z4VHl=$qyeIiAU#>MMz1Bb6IdUiEi}2>E(*afr`^0aU6s;487O}o}$vu6S@&Q^m_aI zZgp(1gzEH_i7GFIjGL5qm?z0b><_7A@!Qw*Oold2y4<{Q1E-XJTqZ8k{k}v3MQX@W zt~1}D{N!Op>8aRFs%4;pOvt{GV^PUIft$O6H2;3&t8}!smk2yng*hF?ZPEclOVY3? z;Wb(9QU|tSCmB7(kK5E$LnUtA1tlG)bzl#ub&iZ)8WGnqwF; zxncp6_wA`^^{*VZ+t_~yS(JjZ9yv`4dgpbOti(<@&x?eD%0>uMl{OjCN`<ysm*!dw=M#Su%G3T_~~GA)zR>Vv`2Ut}7d7Bl?@CSL)pl zk~K1v`wFEwsQLLzw3UA*(!fs6|0U#(!DD+V1wC5!yuI3=-22>by^k@}+HBSDy}u2! z_`aVvwrYHN(pkX_JIcvByU#`>>E%{9Bm*P;n(s|iQJ!gC@3wvJ{}y@0(cr=$I9XCw z_Tg1R$2OYoa;xWk8Mr3`at8o}#L=E1^K=3OgE)4R56twjRvVsQ-yyG0|Ijkvp=vlK zSnzbWd!}=HyoVyf_l7~Qck}*YoCz#{-1}1(n%Ck)0%zsrWY*s2b6@s(uKk~BcJrCt z=QuBNesl2QE8PF`VWl9~e!B_Nu49`S8R91{Qkzj7@Yk`r4%ws7R7+)x(YFv{8m_RiC zN?uYg*rvs9xWq-Uuf*vO#{vF~yEyDm3aPROVhjb<8B(lLRd8QpImPpdJ5s%>>0_j;I&(qJ34$7#$T1e>Gi( ze9|bcS!e#kgAeI~slTyuxb>Wrt#!o{@mWEQ99Z!~IMvNBGg@pD;;Us%YNcw)`>4oq z+i^DKa@3w`LAlCTDuxSjEAHJ(WmU&58z*K9Nk@m~*}_=(mnF$cxfATgeSh)@Ec6aN z+mqpnzqRiNP132(PAWV=Q5RvX{+KH{em}&qm@h%&XAi|7O-W?!DSCX;hNdYsDR0~< zx(g9HNsAsS=lt!G%<^KXM#UxqjY*Wi)I2Iy8*!C8imp?xtGniq$?d~W@%szOU6H-d z^&k?~q3;8kmhnotgn)zneSgX)h#mD(v-R?5B6U7-c&2I$z7Rp?YLR4%n&pG#wTfpc zEIx^3F70TgoYHh`X*ccD0{o3k&es% zKI>FR{i9-QO>aYf@}i?FU%?J{rEGbV%M!$2BjGz`otf=;&hbK?&JhC4(SHX7B<-6B z{yY*rIh>?m*7ep=_g(w*^D|Di&wJqG3TxQsICsv+zsvf*0BmifXMs}SABKp0y*D)5 zZtFl7{%&p6&;5qm?qK-V>xJ72iU&PMV^bBj<%+Y%QMMayHq&5E zZGS+({BLvnKi+vukGCs$k);KQ=Vp^LY&~xwgmw5%o7#4JkD@0)2Qqfp7a)1<461*{ z5Qd@*j7otVLMOv;{k=}xFmSWYvdTivQ$t9j1_f+Fxg|zK+&x5hEWx-bG5SlUoav^> zM;RgMA5b-#me{4dJ=3C_lq@UdE9o_FdfNJtXqWl?1v zq>u>o9xaTsg`+DgI$Tdy$?9R9ji!k=8i3RxsJi%%mGKUz*d$6DcN}=Wb3T%U%~{lo$clOksw4D2&=Ltd?z>IDOs_d@W%+El zOuzj5X;IcjrDWtHCud;A$2WuE?W~IZgr1z7oJP&fW)yS$N18;-SI03YK3!8?IRb&| z%DqPB+4YaZ`OGL2BmL>}PX(*XO^Xw=9N}ow!FjI-GSQ~hR%xNL34Efym}~(RYmBCq z#jxyrAcytPXwA}vi#uz#ByL4D5=Z>#69GK~J%j4E(Ii($&hZN46=fd{pAUZ)P&X(Sc=&2Q14q z5lr#p*lD5r=dCwf?Wq4yXWyq%-+qu|5bzi2!t>b@cMR$Ao6XYYD=bXoY@)3yPfH}#bay@# zYtp6>WN2Q&RToQ|Fj1q%O0b7g4j`A{7)Z%p9GHJq9MFe`bTSmnrQ)d%}Y)f2K*g-eX{QaIJ93Du)NwES6o;9tD$lJHp8 ztjTYzD+c`}f<;YJY)oWCR}vKcsg#5F|;$YudHCq;COLhJl}h+CQ3 z1a_o1dUL6!?>0a@mLPh3GZ_^q_aLKvm#eo8gu*yVu{s zw18MP;8VYn<#IQ1PGXG5PWt+>afLedLpXgs>{FIMjmiL6ex{%}{Lt)Dxa)*(@5(P- ztBaDcL5;FK&YdjYy-pL8Nz|F9eiR!`u1KsUYEWFM(|$Lago9vu6dLKzH5^Xw1Ix!6 zul!F$iCH>4$S+c!+Y*g18_pDq|6mFt=Y(; zSlvzX|AnMa>BB-(7fRICsJS3V^4b-H%u`pesDv4$q9+DlY zLsenqxt;pd#VG@Dj%;)c!l*RXVCJ<3W&z^;-VtoX1hAO3S@5a3=zEAJc3W9wDwl~V z!uLlt=8&Lx$Vq%l0lClvozo=QfWb)WSOS?PL-1h~WowSBqST?4BJztnyqviQRHH5_ 
zg$;@NnLfu!QG}p1Z!yzRKSTm)l-g<3l@im>96Hc8C@w5t%Kf+C|9raQssDRV2dr!8 zGh_qA-xsI=Up|seUbc$P1$O zYpyrLE7R81Rnb(_)KhOczn%WfqM5Kzy~eeygN21<%&ururhM=O$BgyEWoOsnB(=Ho zB}nA&*Fq+8d)r#?d`)T|f8`N+oR^1Ku*E=rTevm7Ln>Hs#Ps~(+P%NkCOLsbGt#Lo zqFmDxkgF&&6+B>MJ+RbiE*55grBp_+Q9;7nkI|?U{RYpr;$r=yv$>`b(*V7gG9!Kt zyoiG;c(n&Q;TbKKm=p~WZ5TK-Y<}ZPWUY|YEMXer_@yMuKBCh=9L^Y0c@_0BATdpD zYJBfH<8vjTq6X777#@*pgjb~Hz~}mem|N{dAA}J8z!Uc7-QWjHTM4f1y$3-lfrW?_U zL0(LX?l|})euG0&_50k*Ozd?Jm7xA%!S3)yS{eeN5l2WF(F^@x_>aI-demj|z=DWe zwy=u0n=~cn;TKjSR;)|J3C*{^BT2rTYe$%Erp6lhQmUcjF=9rL?bN1san5Nq?piT;gdTtNAsN@G)yo5s0PV@(e&w??$bhrDbYHJ-O{c zBFOVk(~e)#C$Yg~z&Dwjo|qUQVefoF{spEAJY$cAsmsOtFs&(z0xMaW%O-XJs7~db z2~{-=TXFHR)n1}RRD+X-qa67TLwEer-9ZLZH?28Z|pb>H>E$Z2! z(D>d5R(D@INIrtKWT}zp#-acHk8)++x?mrCoJ&v1{>zpohfE=g9IwBW3}K!WJU2|J zUTb;V{oF&_D};h_ zD{ijGA--TgyvEnP-3z3>UuZ;}an|u7sQ;_1DTDtv%c`-_yEIZP;{Xe{&gBRgD5W}C zQ&KKg5}$BU{RAywl>ca{{hkDR;+doOPmu{~a^5k8XX=~heB{GUD(ZvQH?ph>tdIvg z`4D(7y&;jE+?=s#{OZ|8keEX*qd2J9MbLlyODUAvQ?pFpwmi&mX1#kfupn6pr^Wnf zE+HfprwLABe*Z{OF#!S5JG{6ay@DTY@~`EU<}*!D5sL_hFTy=b6VhX{Hnff2z{Y*I z<2*7unecreWExd;{}TdJ>XqNh#*XRlwxHBfsTaWf31}S*M!(esBgP)Rz1{a5b=wvg zSV<Ld5!ohGVV3>0h?(Q5CL+CT_eN0v5)gT?7$jB0Ld zwM2$+jwe|)dV((aB1I9%0r#!(OIq_kOi0uN>nhrRFI z!Q5jp8B&l?Y^r)a{?Xc=aKdKt+rnMq@SO_TL%K?GONk^+sxg?&9_gPGlXV0bmE(c4 z>P5XYunzh!Z^cNQj^4?=U8d9BZqt)mJB~3>u zRE!00as$QkuS{tYbWi4_OrSzfRi!c!S#T1Vjs(J$Bx#DsbOav<(P9{iB4gDf!Bgo?3+V^sNyAP4<~bx2C*%|8WcJEz?VX3Wm-VM zk7=5Ma|mzn{&?C@(FqK56DPG>qQg#p3pSuZFmYH@y=CihS+gz}Ya#`um|w+CaH!ti zUZ3BeK<`-7gniXdHF%t5ZqoEMn_NvR#1l4hU;n=fLjPC6rvF{RpO%ZBfTZ>Pb?e=y zyQD z^wC8nB-wdwSkTbT@OPcL0kE>*bJXEQaMM$<%9!8QS(D_Eav{e%Bjza}; zYB=Y?!e;~vio`7G^Wm8weLKcUfazq_gA7-xxe#c1RLNWO3urpwwaM<9exxD&!jiVA z(C=NbiwRuP7#c^BFsQVy$1F=2zLtT1t}6ca-nHyL)Y<1BDZ3Mp<~=0;jU9;Zm|9pi z!G3)n&$23!V%xc zK8Uq%;clgt!nngh3GZI)^+?>PzG!K2P|sg|a&9tdxX&J*Dw5^x*J+>*Lb@%FkWs|N zsGgvv5lpfeVkk|ofyzhUumYP}_X8rXT%5Am56-{q2%SP@oL^1hoRBrNvn#h&rXfFOxlBY76K0f4HTA4Wbl%~nCAyHQ7%jD z*|rQAIW|rMay&a86z(e@NDOvMKaXZq2k9^i2J*ftNUM_>3add!4L99g%p)q`4X-ig zzz5MNdL3jy+}(+b5Grj;v<=^~cdBE3Aq;cY9tZp*>C!#Paal2KlH5ZEA?pXU^o_ED zUX-XVI~pWZzlcKOjS}K$EG!ZI%;D$*7A4Xq{?d=Bb(Ub|LwPM>N`*(E)&lXLR8 zmA_|dhu}y(1**3RLaCRHc|1Zp1*HWYmGxSUxaWp=HH%*Gt*92{VoL5f zGi;9Gt0p8%^OrNnFXzg!d;K<4-ND8a8z-ryA0x%CU4lMYf)CB*yo#Wsye>NxLVutGo)T#(TQFsRRb=QPKAN9FOHsyP!w|D1B7H`sQ5Cm z#B`ykpuQ8LWCV(0jxtK5mQ~f%k+W$tru-~J)pBi(p#{ztf@-J>;&8;b3Z(hV%$EeZ zo}NC0ub!&y1k&6%{a5{$E%;ae>|MA!>rf3~CFd@acjxtW9UHu-^(&m_r!-O=q+L7= zfB!quy6@&|I1_ViN28M^|KHC=uv^bGYrDq2((~yq4-uZ%s$8ddT^7zykG_7&_AxQe z-@IpVP_bR3W|34ip?JTTd4?T{FjbVm9a|hn-d}{}-vVF23;usqbpC&J{s`rgObX`p_%u;=XOA<-0xq$wE;bN8t6(H~itjUEC#v|ki&Vk*J`iWN zPxWj3Qa?m2K<-n6BsZmSS#<-phb@=Pw;PI7{fdt-`-Wg>RWh7nGBv}sL8z}V0pyc9 zp~EtnH-$|%lJXLg{YAv$$i)t&Ti*S{Wctk_oNM z2N|sHcwnU(BlwQ&BnIF1YvIB}9aVB@MU6+cAQicp!MM=MI0y^gV|BA?#B@7-WGKXh zH4QbCKFB?Qe*B%a2;U?8`sNv!>lmm6F(#meV^~!(e$&&QtinrbM!Le$hh7)JB$krJ z#8}>JG{@KszakYzo^j(CCsg_4sq;|vacnhBwLx8xI$_X-ZaHHJFTxRrVdYPe&f_hN zxr?UD)J4U4cEe}<4~wcApNVRCyBFlFr@rkV`L!RS$a9Hcfk|papSle80i6on|8XJt zl2zo3;z$wROOe@_2$A!6)sjL~gKe>3p^!dAa5@o_u-J~p&W%y<5tAV$y5i9##SAJs zb;U$3?6p`P6hqy9rww6^$ha*S9XqXvAi|6~%TnTW`wh}iJEO^bi$iz3Ps4?WoVm%9B6Beq!LAFM@2@X!iOr+PB9oHP544&6bQiE3C|+* zQ3U)WrerJhN@zH|QDj&Rlz8VseGuv+TCV^CNx<(%HRVL3SkBML9s?(nkg-rN{O>{_ z<_=q$r|5y9AfaQoI-%sd0Rt=_ncnr)Ojc@ODcxPNzCX9^Dk?rZ+&rpImF}~>?c1su z!A>>8?e=q~ci%J__lLYVdKvU;$)S>vCG64B7U}(+$Fbdte(+=k;Hf=ZR?esIv6FEh zDUWeRZb%HZx`H&2G&b8Fwe3qD+_<7>S)!lJ4IR-Zq-yh{F3OW?Nd-eXXxQKsDLM4o zX5>-J_RDKvHR^tl`_KTGiv@%eXoOvgZH7ii6T%paCS|l5(Fa!-8gFASNg#3t{Mx2( 
z)^W=O_q8+!(fx}4nGly)P8341@ySTE7!876FCvIcH~RX&{)O@{-oG1|GI{p&Z%7~y zJJvOA2yPI?k#UYYuX@+p$T1e=#M}>$#=ApxbssDtt=Ms2N_K}8)iDyziO~ax*-*?> ztlGLsC-vH7MG=-0Acv`xCG%|XMtj((#gomHAc0~LO%X$ODpUar9GA|Fb#!V1NFN{pGRtu&Vv4XePwm@-t3!P)`oUW*O24G_RdG1#UvyqbjaqvQz8 zZA_Rkxr;J=rq78GO>1MH+f$bxxsq;#;`6BlSOUw?zP&W~Mvv+c9X6I=y*A|O=Azt3 zs_+wMS^pJYc(A1t)9AFkrJ}lGnk>=)y4Da{oJyc=g-dBDS!)hmiRj0EI+|TtrQDP6 zGL>3OT&j3yGm%*@7J{0M8Cm)RIK6aaxmA<;O%`;lckfh_^`3@Td4R^|W#SP2s_-{M zY9+6FUP(bw2`jPPfqP?XSnb+Aq`RWC{ek!5%SjrHsZ{oynS9?=w@lQH(@{iZi1GZXPl4+#A`AV5$5U}m%F4`QFYyjM@rh|{-ksjD-0WO1iFq9Q+)$AtH5Nk~XY zOG{HJmCpu)Ol7`RaF&`sM&0{9K#BjG^MAf z|5^KaNllVzEXN}QtM2qBsNT?CLQTeu4rWXZfVMvs+g>aLMA<@~-_v(gHBl*gq{gA% z5mNJvBLobOld2XB0b}u?HfA1*Ohdf9KKS~xo)zHu%0q5TEQ#18o3p^{nUwcWk3{1Am z`GK9b{l9m20WIzhZDrsGL&xM+PML%;e?~gGp&RjLTP$~lib$ly+ZTfEvWf~}&pX;p z?Syu!W-A;iPjphgfWWmz^XarQ+FYH72M7LJ=i4yc@bK_fTbS$bP7cDKbC{%0m#TF; z;1GDt`YKg6074< z+;aySt->snz=L{2x%~Dli=oy+_TW=#j5WqU`i3oAMcmQ{rCmgd zNcJ}83~oCAt>~jNu(&%JRrueytxD#{HgduQ;G>|ZL+7DW$ka3%qzGBlsla-Wa6|3u zpd=K`?E|9prSUw+iZY~syCyxP*}~2Bc-*o=SBjPu*6O7?Jg~FUB#m)KU+MxhvyoMD z^Y=A4)W9IMut^2-IEcYT!9;JPL9D>?WL@06V zn#cgZ0Q-dPtbdPJD)3NRbO=kC+1)SI36(XsmQP=oOh(5@BSrNyiKRcHt^V#J?4;&r z^BxWURoma@Q$&L9_8U)st(U9M($XIGRTQefcY7~SWp6!Bxou`WF9Bk%%kHBeoIYVU zfEO1_&x>P+>z|y5&3mJQLMJCjZkzQm5*^(>6yKKU2m*WfoJL~zCVst#@Tn0QmY%H4 z?AwRE*jmlTriXWc%0IDos<<#dHqMXAGk6J-~~EMmVG9@sJ;y z!mR0ew=yg6qvv(QZHtfx@{;Yg*|!=OhKr+zg!Gvoy~FK~=fOcLjY7KPW!F8Xtvoa3 z(Z7ogfd*5!io#|sCCyPBZE>-tABz#8taZEqJZh$=ck%uj`<^G7o?fa^trdRqvC1Wr zdmscoo85&zBou=rbbYnHtmjc|?s%i3AT%V;eH^Dmo)N;J zUX=#&eZvA~$PT-~+F?Idn;s|I>>#&G+IDRRxC|hzpN}4o2M4AWu_-A@NvSEg$tgw$ z3ORxf=Ys%S5qorSD4)ZLv{+v>@zsS4*6sumo!IjN!?*VaV4CKSZ|Yj|5E}cS$k0Z= zf-ELin~yOtr7D+ffsTIoFe#IMH&HGbrIWPuSkY)AFk2IqZ>n*0bV4T<4CAofSZ#K9 zeH3!rYIIp!8bHL5&*1bJZ>b&?^4fG5Cg!E_FKth1^?iqARRY-qKIiokN*}n`IoOZ7 zUtJw$f75ihmAX=rIdI;eZ-mB>plyHvmQTQ9=4%!`CIS!!=2k%DEPOWu{<-0r8PyQI z)d09+eFbF;vF`7CKo3U|@|ccca~@VySldVBfb_Ax7as!6hOk~8EioNOxj)0CK21Ku z=#1lOY=!Bz(*SzZb*G%W*+mN3k6&3?Ar^2gBIP4L>ql6u+ZlNNdnWNV6p34_*O-B? 
z^-XMWa46r77tq%-)OI>tl zki!4>0q9lbz8n86Dgt1ywQyp7gQeha-)g)5d8)ZAQec39c?9G9Y~i3B&UMGcl;J=~ zEWSXAt|$!uGt0{fRp+M-->p~OcIU%I*aWk_oJUe0=ZBg6=gXcR-Hx3JFzYW#kPS;u z!0g8Zz#>My(~()P)qd~kulr^D+tZmu6y9~QPKjKC=!30>(>9qeze((xMI1)!1!j!J zYvsK6*)*8^Y!sHwzQ3)qU@zXN0LTd(m*rVmLX6ZRQ#a-^<7f^V17H<4dvhH zp`WHkc%=X)Avd=8&@XN8;e}3g#OL4kaf&1DHzJMRF#X>Z)E6aD2v5BSassBQ90yMI zvk$d>N#-;aNZw0}a~(_46dwqZ6i22=HJ;ZL`UZ04i6oy0kkoJZPnf~4| z)K=*=BP)5Mi5na3hkXJDHh_3p6ywZ!Dg7bHRcfQX`(YPFBBsRm@|xz#medAdmXCQS zvITy&7VRS{2q$NzzR!k|7JQ;O030#JhJWp5f>z`AS$Mr%0sV~ia+Sw%AL6xoz2&W$ zO`;e}cw-&3{aH%&F-smX((4M%-!c6+7;j5Ux^|M*9}Wv)pHynIFo}tEm9-MGT}phO z(Zh+ca^LryhJgpTpOp}BJB6b^5AibJ>}+a!1-*^NCR+deoSd4H&FexB0_7qb-mZ4K zZHn04hAIdJYd0n*r9kd>BzzwR8X1ibt2E!QFW;HBLT78mU^y5B2)z)HQnE7}0Y5cZ zMD%SDtmCtT{R1C>R4FG$vQ5|2%szXvkgbgB2gINF_+84%0MdGRaPUWA4FgU0?BBiN z@@XF!an|;Hp&ko&>`ANDVz$=zZ!XuOg-QORZs5swT0t!~DPSK{zZc}NS2uqf!B8V< zv6VA7ODooah_;JjGnr|gCFpV03mrDW@54{c{V(AC#pBLnHdRz*RBPG!#v^71u(koc z92d^x<8QRPvCz|xO^;UAQ~(DsAM1Zxd|I_WP5}LT*rOXAr46F-|4$QZj(&Hvzt5=E z#LCZ)EW~nmGdDK}{Paf&>eb4KmSkPlo0nesA?30`OIu@FP#0EgN<$ffMB~Dt< zLwz`tulC#MU^z9p;`PAUAVdF^h zA4^K=%8qsFO!7Iv5o_0tgISp5mR|$%gLU(x6%YEJJ~ldu@Or6!k9&}nzJ+JWpA7-$ z6~%vb3r;x6;~@QR$3^L^pDU$NLwzBWeSwHOHi8=JkSA4fKm&AD)9*h2_58c1ryfLD ztBqZ{MST`Zs{80@ND$JC8R(n)z1x{is|jeC zxNWl>AaMgUKfFSOh5$h6rbJ(b(R{B3%rBJ4t1acG(Cjbi}~#ghuqB zb~?^F-AMM@3Yu@3>}xSCrNpaZxelKRiDQP%94xEB97Q=1xAKIJV@O{~(Nf~DOp6vi zOWK_b)sn}juP3RfCQg>WurdiU!p{v6{tGjdUGqQkSU+|O-75np`KscGK`qMahYZ3(lA2(|$yK(0(xqU!IT$H7E zK}{`L-xWw|3qDRzb+V34OHXOws_p%X=Tdu+xxSu+{94e7SB)_UR!RTk zp?Owb9$6shw$~2VQc@d7-G7%VYrCK3bKkHt@lSNfr%q2#nKFNN-d}Vam;5DVg_7r- z%tt!`x^#ps6im-APUi{WamY3%)&7VgTSqe}9|?cAKsQi`mg{nW#eRfOAf4O4xYc@oeSt&&v+q z!>V+P%4)4+A^lzPZ?WAaQ@xskf<=cilKY~6D023yKxVtoulY1o z=5KvAD+7b5jnB}brp^>;*{59;86{oK@Cd;2)8{^m&ue~eGfUuSQxoNXaK%py;IIYQ z;E9A&+yFw(^YymI_I6}&nf!Siuhj(J!*aEzfc;fo?#nIcr;M!F8IezM#TfGXcPL%E zJIz_sAFrpC3PO*Rt1-`Wb2+|m9-P&zDCao>`ns?_8Jcezs)W z_UIwhe&6afrK`j?yV4QHd=(;>IWtELf^@z-V-so!xLkN&Zn*7zR$SV}!Wn*YR$oG4 z=~Pr#MsDuS8tU#VAa|a=)~*ALBb>MFseyE+Jlgne08l>*8JsUw7^b1zUOm+na5F63 zS!^^X<#&WL2l548FJA#a8Hd|R9qh|#?G_skfbyM5DowP%-|Q0jX-mrIxDUv_FwQ~t z=PTF3T`G31JHIwM+`1n6Q7k?~p5GA7rqcX{5+30Drq+utgbvraj-Ci>5T|qXyidl; z{e?!{lYT#D5NM5>=lcAc*%`oq90>sPNU-kuFsyBNKtTW99Ytg+^EZyBFFc>+=9qOm z3RpTpOn)BA6T1?;^5rH{H;tbuJ)Po^;|D^~p8vsOm}HGT%;)M~EVN;{T8MJjYz%X1 z@BKoIj~C=(pAq!|kjtCf8<>IHNp^OQx@Gsvi;V|DX)?CX>tBRqkmVR+U&sMYCun>B z9k@qb*HWwqpZ5@#TR6yC8YXWMF>BQsJh#|BXoRx(SW`9yFU8z(fQdkl0tQpLf-m8H zXlC5MRN4h1Bee|AO&oc03damc4pPBtFmz)SeI|(TX>=mVAE749m_H}y=LKO-1WYob zO9^UlUQ(N1`I}~Kk(82 zKdhZocV%6;Ze!cF?TT5kZQD*Nw(SZlwr$(CZJTF(ZMXdc&S~c^H)(6Ec`?@*?|7cx z(}r{xb%1~VppJpaXraWRu};8HbYcrb|G=5Gp8?|nB!)S#CA~zFhY4_KMumjJ|CA4C z?BvXhc_SRn+rXU}k~;42RR1f<;FY6DC=x(Z@KE*2K$iyt8E`namBN9OCmUecm^n)9 zx8+CtGDL_A+DtHoPnQ}!&|to0%LOWeS=uU+?Yw7ujeq(Qbr8mwl5?%NtRl~g@6%rux{xYPlLYYD z#N_l^E%r;UC^?n7%2`FsO+AWJ+9RTFxFsS@#%L*wHD%~%7zPc7R#q_G#s#i{8<&O^ zXijq&Z70_qbBml#ueS*zVYRNYc6DKHxYxenYNXv=1=#)*a$#@;(GcXGf}_F(wa#hN zEt}r39QV83F|9K9`y6%L?5kwqF`{haqm$kM6f*!Kytg=(^zY9n^kwJ(5BT#{_85Ny zfKk>SZN{~>b}Dg?WD9jE6vV+sJqrnF)V9l$a9B0jEW!A8TIb0-u4Hko&aD;OYFmid zU2s%U%3uk2Jzow*VH&3`+wOc)p!N$wo_`^3kEm9-@Ax44Izd&%tToh#2ElfXHhz7~ z#r^sDdf|$sm9Kk(`*^+XJ|WPlc~2tz0m}==Kwc$}e89`>ueEo){LAcM^hHZUTg%#5 zHFwinMJpbQM}Uf|fm%shZC}Odj$mWq=0MEh0NFdO+mYMR2w;5ed_T8{hJ>_=AMcg; zEq1t{i^TB;4%8Zcd2Vjp?fg9L{Heg?=;D80Em&0_B5O4!@=Q2Q&u<8lt9aVnO~UX& z!n+rwnq2hNILPJ0;0`Nmkbwb?5r+2ctb72?HW7 z;J~qL3p!Zqe8SMO=?R;c@~HN z3xU#nH}JH*C&>Yu?)?Bj$Z7xC*nyz$sg(x&qXVcc9p;u~l1S^_e#t-Yz=x}I%?1j& 
zj6#Mf2sob;z!8@3<&$vEbRbsoQslP?>WZUrUmgq(cbz!8`ru3km-5zjBYiVwvu#vY z{%hPsT>vYpuDuSbY|eTS58}zIzPf|PO9*m#M%k*KEhxN7gKZaL0ZT>oWE|PrkZ$41Q&9O$;4GzwF4YyIT zFvpQ;U5ily&Ob~JtYJ^MtSl^d_MoT~a{m4RDDC%MoZs&Y!j2rkkb?{GzPE3CTIO;- zHt0E@dwjgCtR$CCy|;r@6Ri7 z2a*d0W`2$7t>QH-sL*RH$GS1IXl`zMw1Z1#ahNsej3J{eF<_)b&<}gQnj@>C6A=G- z7R17X3ShZH3Ecg{=4Q&>tILm6t)&0U-MjBj*D}c7niT0 ztI+pNg^%l#nvkndiPeIq9)krEVAk&QjdUV`GiykN9XnIaC=dQUWRN;eV9F#o9g+=$ zH$Q95O`Oo$R}W^B5hw#~Okx&E7^gx4UTd4^5DrfZAqOzlu|_Q2CQUDyiq`j8#$+^^ z5U*yLlE}tQvKP#gri;)JOBP-w4fY62GT5ReiY&M#-%=-Xihz&~X{9Z%p`AUC&hIg| zzEvr0Q7a{)LjwWa)Zib~??09i|`GKqWBk7JV z$5XG{r=KO9^XvLjO$gC`Ub$% zvio_x?iH19H?^!=_qSVzmT`!MydS090Ztg~u?}hhi9u~)#UavkP*cMlxGFTaBH4gQ z4c(8^lgtxHSUW}{4XK%h37KlX8>^1|QxWr~U0l?7p?PIFZQ?V+94NY2Gzue=!*<_R zxvO?-HTaY<*jMWq!)aN`Pn#B7+tAzJi zX;KMklVY-thmrL{r~wdi=@;0P;%F2B!IQEL7>&;Q-#O zpFbKjiD2DfJ~(WE6!m>%8LmhO1ujxQ<gz~7JoCIL3`noO zF8W*S8(;@1SMG>a(&_%}K&Mr&mo2n|>_y;}6R&*S_Mb268+_#Z)kjyxTsWu{akbuH zewEQ=yWaSW_TrKAQ-nak9d7-M4d#>D2AMZp%E_X)h1qyZ4QtxOqW*k3XRNcJ!RVQS zL^u*?G_IRqQeS6;Vvp)o!$Wg<-_CY=IBSa6TxRh)?%(c_(#e=vD-P{-D2E}S6*y)@Is_AW z*{n(W=m2QfA&qKuD3TO#PyTpfgGUnG$)9T{|LU5^KOh#XRipj}?Dl17s#YO!*Hwlm zM`5QDm~B4z8>DXs7Y>YQC=*>tzdoSI=J?lHr)oukSBVNceL= zEi>u-{wQ!US*6zzBf>d+p?s+#jX}G(xTqZipW6-07~NGAD46ei5cS97YUzy7_isk( zaN@Hn;PD!h?hTQX-NN>s{PTsmb$`)ykHP;{Ew$Gt6^fiet6uqee^Eo_dF13zbo1f6 zQ!I=7df{(%Y9^a`50uCDSLwI!$>`Xa$J5dZmD&(8PyYP{)J$-%6G0X;Jzos_4I1Oa zEiz_qU2|fQ+fOivm19Q3EBL z$!t)xIBPbIjzrz#BRA>%H>lIk4z9DlC8oKp zMN1&5;RV_AxCso*g5Hi)R@LmYm=b+cp`;KLD$Q|y2m2J5R03B>WL@DNq!bIZPp4Mh zX>g54vD-H%ez%Hd8(UWPr6^bAyU6ioYZCOHa1rIY`WDrZ|cew)tU!weBJ4R=2cu|$~5c& zi&6Ww+wJoa9s|O#iz^Hsy-Sq|7^yeeZH*^c!9(|hL^)_T+RQb2u$rdZKTm+3mi_(9 z`TD#eLK9NEtL|DSPw0&#Vw9QAci0Jb-KY-wsh>DpE`_ zs0q2;qSMC`wGDLYJr73>Yq>m8QD^jJ$AEU;*^qQ^kG?@63{eWUe^fK@RzePlLqV&i5hIzb)55I z@2ShE0hrQ_HXq1UypC=q*)fP3lQ;$jU_r__d#7m8oD&)7uqyVAdJ@y*#aOlEO+bM3`U*sGE<(F9%6rWa6 z1xT%Dqm%aK>uv92>CALI{)8=*ON+Jy7> zF9ntH*J`%i_NMJk&_Sb>GUD|U1qx)a%(rICW~{5GN)*+)B}z1epAc&q<7v>f1XqEY zioIm^Q~#!ogFjM0*j5*EKL2IGbl3d>r=S!qa;pdog^G|Vz<#`e3W~JtCkXv{qxU?^eL$yANh9%Jl{& zm{j$y;eW86^Ie+1OmEn)M8yU}ldMDEIIlw`aq%?sEGr}OHD0as4QV<$ZJUFg{OUlaPDfm|{x z`9o1X@8`5``FM^mTb1?la+0P)+#66pe4w87tIySY1AIsxWl8|UiRl^hEjFZ#Gcq}i z9dIodfiiFXVRhe1>Qv^A_lIq~tX!MWUwQ;CrHG>+ zOvVD~L;!FQRe?>+8jitnjsWJ z+jc(9_I*&-Xaa0>+O9(cM5!!$-GRh?y9B$heO^>+6tl_C)#1>;<(&eL;zuP?SCwxk zf?~f2k$@EtsBDj8LeBA;2UJrDfl70ZrLuWCunjVx6)Zk7eo*)33dH>t^nIRG%hhVm zXjRKa8_WP#S!5qh{Wr^9;+f61OW(Dw|NL@b6g1i88&?zlB6DPfY8I5yO0D$2J*Yb( zDQKw7_suN4G>o%&9j>f1EEcg#{4>@aKDnVR*v9pu`2D+(uAX_x&hC5b* z3$^gt-H~WkBoA?=ZhwHF4Bpu%VF1Ul>7T}7dnk#A&_PCz+V`ac6}gn?YxR}D8SxW| zQieUV-cg_sb$o6Ye`t(eDeeMwd(ZFXlh7qiy>-kT5{qu6_c!0tjTOr}M&dHPS5Ykt zbY>+T5~5NqgKS_AbG{8~h-B(C&ARP+z9PNP2U?Y{aB2($(3ZNa|6=8##3nSAmyyIxKli5n=TfC$MeV{RU7f3L7pL%Y7P4OgTfG9 z&lb~N(}*}xhvhka?}H5zN*Yw;CtLyuTvg>?# zJ^`!+U5xxcI%MQy4nPc=R4M2&h&Z$56?zW5+6c%t12Vm?qXyyfy?p>&ZwZ`xv#o=*{Z)c#>Bohkwdeh&$ zpH94i<1W`^_BiA@PD5C)?)cpBxgBlv*#Cq?VQSt@$wAo+Lce}Qndd~rMD+2uq7VXG zv<9Xz=v8W$O!J`*J5x-y05$}DD=qr^WN)==MJ1;-a`?BV>5a)%=zkE}5MrVur4H1C z@p*=0@U<(<=%<+s&)ED8bi6m;oKHA!LSTutqN4=nRi5#?0~8MAM-+>&@O9bwiK_uL z6tqENzw^a2Y*sU-g?g5T-Ww`C<8*qjlzDPxwD>e?sVXXCU~seUq2b|qlgRob!c(p$ z8#*?-*;?S!H^_zJ1d@p0{k1_@OaubX-eTb49V0rb95&0u{_F&Ewc6Rg+3=2i+!tLR z&j~-ZwYLez!|~pcgmnYc(N0~Ao(&y#2W2iL4a^>7pO9nACHSI)mb4o8N?SYr7>dtUr-0f=p<$CGENa|jk z(62v0-AFAh%VDM~YjEmJ>FHv~iSLmih+^sU(}9>e!brX#ilT-1majwmdcCbeae?AR zKI_lnZNks{*9`Pquo%#Mj9mGkcYh0RMOeXmT@>H$4075*1V4LhiyBb7m5 z3AsTPlSh~n%jFcsRX_Op%k%Qpcqc9T_P>RO3zgO)&mM%;RD?%Uo!>v6_mdc7&AkbK 
zD4_w1hX1`zV^~Vo)1Ba;(U#P`;I?!uXE}~zH~yW!|9`XA`z^RqI3|xpqpL*n!}E=B z!uMg^gtmXIJzDWKrE5JzWhGBJVrVVu^sjOV~E(i-(1pWGQg4kZeSY&$UW< zs;oRM@k^kQ026xsx#bUA< zyq~uN{w{c=pr5~WS6iVEDwaxP$(|g&X~4miCDSthF^+)+L3Cae^(8=6v{=`+`TczC z>`doO8A&QY$WcM?!F3%M#wv~O772Ypv>>*1CAjm5hUSJ2gGRa@8@^uavA^3F1zgCR?C_;>HRN|%Hke^;)>PeR{+L&ZAKn|LzFNL_(OHi-Yv2lec@nY z(4wgz9oz`mNPF$`--ZUwWLZp7tp+ppkTA6DyQ%jbF?d{TpcswYuv|gc>rW5aPI8;% zaz{M{01P~$Q?sEsj`to4)I?r#+PYXUg?N-j+ce#>*EhDq?dJH_UFuL{d0a>kD8b(o z0DL=&&3%vCl0b%-yB09;zg!*FJa6d9o=QHrEpiG;2|}?|aHyp@3LOv&a6Xw#=eEUZ z29~{3#y zMjw}wP@sf~+i!JeOAFrMi-bCOwsOZtP!Ov59n{I+Yzrz@YKINrJl~|!7i27L_#H`a zO(dTMWpQ}?a5~er;LDSG)rQ;Xr^BGPvjL!#^8b$0JsxipHZSQg5-Y;syn;Y%f+m8Z zD>7p_be~X`BT1B(RI6 zbsjw5{HXCmT11$K2T6N{CVn8n80wViJV>P?4n3o>IHH!=_SI#>4mDz=4~OPiK4}_^ z#sTcxRt*C9%xT#+hG`?Q8~BKmP%PrwMEE{#5`D~}_{Gc>0WdnMo*@=hIzztQe=J{q z?3UzMX%FfDOnAIuKBw$KU+*TN2C=VlfLt{=97%sF? zuMBBUo__7^Gvl>Q{S;y6B80PwP)Q{4?b^Ky;%&p(8`;2wbY{5{|MLAh+Kg5L4^fo{ z4oI7K(&WcRCg}S%&B1iq9tH(TOi~U{O|wcZD+d`viI(mXI5|AV=k|sIf@+KcPw8Cp z+U^AfS$Dr&ZL!y9!iO0-p8D70{^P5+JrauzD@JO+o58n3bi2V0sLP z6d!A?DuvJQf`wLQC$oh-Anu>%-=okn`t5@E1vbiJGKy4-q8n1wts=BsuIdG_Yyo&x z@VY8%io2pq+nazn`oQ;Y5TYl+0oo#lQRSKC*>>W*)M88hYm5Cq{;bf?u@ORNVP@<6 z62DL67E0WuvpJE%~6!P<`Zlqonq=y8Edn&?PnM0fn309jA~>oofmhdsyl-28-*%<1=2 zKoF#+!&H3P--FG>=BxGR(``}?4vbTev;Coqm1=NZBgR0fz8oVyBKupR8SV>zpyQdu zA?Z;p>9~ufni2&8Xk4IDoLVCnwq8IOoo73f+xF=GrnjR+CVjr!$E(+6CcKjh*HKAN zX3kFN`v^e0PPB@2C@U=Z*)Jly6|bh=3&0OTa6emj?QgwX+@H%#${xze#B&sJEIdIX zn+OxV%Pl>xQBg2Z#kHOy>+{bt$|7xRTZWr*rc$2CVdekv9=b;_hr`}b6FrJAS|AmV z!P(`LjKKcB8zIyNx%-^Ym4by1tyiucC3fG;i(>Xzj7?QBSA%gXk5bW+kv4D?G`W`_ zsx3Y7=G|x*?wW3fHN9baLzlI2Sq&gsK$UZ8t(U zW{`Gb((8-{L+%&8s_1-j4(7i-{HoF(8F%TmIPa@Df$YR3sMU8M6Hf=+Nwp^Dl-&$U zQd!YpVkh0n%V$Up$-6>+gqn)fYdou}3@I@OX9p2;9GYH=m5Oo`2bN?07Zpzu;c z4PhHos<*-57p-kNm@Q%?I$VA|dB<8uViqm!WP$5Xe2FI+;bp69T-EB0E~wZ1B1@X9 z^7mlA3Pt>a=DTLXU#n)y%vqbmt`TBJfWfDWgYT7gpz^*gc11jfBIf)KnpFQlq!y= zjua85G62cLTbqBZ%+kyKCf}q|cLbs5^@x${03c<(X&A-o;^xk8L4pj zT{iRFe}JTw%s>8Y7zKncpn?5pgw${)+|e+SK8hGz(80w@RqOl1Fp{p)c&tNC1civj zN0&V;u#TrAFT)+$fjfjy#wC@VoFuXLt>>?9ADg0I>q_Dj5JF|8O@+B|R4{09=#`49 zq-Ac0Ko&44C)3$0<*FBv=(ORVL(*x>HTp9S1l|Tp4aZXuhyNsfdzA~HmkTrg zYJm4S=q2qvQf2H(seW_GVI=KhiUc)zLFBE=rX|9nnX^0$?!$u0{b7^wZXxE zp0|u!WYzqfCtIq?AB4;?WW9D6ggSD(rp&8IVpR3r>+Y%Zf2&FFMGp`vg5go1jr z0oQ}v&cl=8<;Za9vXtrhGoSZql@?|=e~9}3V#S^m(z7e&|W0T3zz~oA>$I1$awsHtS=b^5OdR=Y<@GDZEE!DyBU~tHH&v6ZGZ%t(E3v= zQ<$5ZJMl`pxe;&syxph2{t;ZtbBBuwkKV-rgnh=Hubl-2<=n2b*7O@V2yfjFOA{8+ zaDh45VJQn4u^~8XfBbE=n`?8wZTgqK@NAAIR7-(SN)QnCZ;OsXC_OLw- zO`Og_&`@%CoFTAO%*!b=!$k7K+n98j{uC;tj79GcM=WI4dYv^n@8joXF0TWz+eRVuSlE?i9&_2WJFln}#6DI5! zNqf#xk-hjuqdsDH<#s6J_4C!T^C!s2usOQ(ruzeM&5JkWcPZDiD^cs|%e&7<${5E( z(&Rma^tJwhTK}nxwYPuHy*iSt!H}-GC|zvvo7|A9n=v!e9tyYECevuPFK|?Yp3_A$ z1U}X_(k4AL-wI=*>8D>*Y8_^k#dhF1$zuTi+_y?{c}vsMNAIkwT~_p>t36e9KD}VW zF|~8A3#B=S96ioujl1cYJGwt_!b&IZ_Y;QnmX*SZ#Ht!&J62=jC=zA~lYpBv$0jMDadue<<}vF_ud$5sMlug07ljkZ~r7c5z;5 z#*ouhk3&LYGf1<{M(vl$wnI?6zICFjyPDQNb#<5RaR@|u!g^iP3G zCCz65eOvpFIZJO8`4K^7GbmR*f|OrELpP+>Q5D`cDBxNAW!a|mSyS;AS!FWaKnqfX z^^a-0vu^1AIdqcCh_TZ%8&VoemQ}of92FHpW$9&>l+5vxUAkMdu%z2SJCosHC>*WU zj5?;HDt-n;OY~t8<4V;TF|gyMjY!;6B_1uE6YobI5Rch4j7o07)9s{7*+de;EELqxJK0NWLx_GetrTX4y zMYlADFdd6W=susXbu0Q_*?XXjAgOit$pVYsz=;sx;Giv2IXr2mqj->o<4qEm-F#lX z>o~vc`k~2sV*!S76W~jLyIkrM?UC>KwBS?=9~(~zo^83*u|WU(UY(kHjHo21Cql{@ zBN0-p7d9m34%(axOiULI$TT&^*%tvkaj9IDP{5FPq(?^?Cp~?ve6JSmuCEyKH-@QNnx-7!&d%9!A%os9uk>E zBPyq=*JQyN@=PU zm|LaW70N9HqxClOOWJ+M_XHoID4l6<+qdOu2o|<3SDV~cJ35UHmt

5`D7oY%mSZGgyE3{m#OUcHo6E~jAEZjvSJ{&Y5 zt{T&CtS2PIMc`&yy=!bf{SUMg0=wk2TYV-+91wD^_wkB4zMNafimnGF3-Od$Pd4lc z|Lg)EN3cvI8$4jrjgF7*?z6m=HF_J&7YYD^F+2Lc=~b2`yq{2{s_12&OggY%#?f-^ zXgf%$dir29)$#)0f7^B4gu05Fqm%u=0!gVu5TT#ZfKHCgObhi6#|U~cKujledC~@t z*@N%mw6xJ@M6T9(#4NEVtuHY+;uZVVNJB{EelRhGIVC?%m4RPR8x3*0qeoNNMTTpI zYfxBS;F-J5dI~M5G1=gC3(5{{(~ZVcYs-6u<~=T(L5!p{1;sde^;Ib8}=*zDM(h>4PfHte?wwcJw}r(48N2R2abW`ldplY_0^;i5L)PPg7OUOfqo zFelW6%;UO^4DgkHw|PNj0vLz5F?{9ua~fNNi{2RPSp;W{XVr2 zSS4=JC4zz7+5M3LcM_$%88M9<^%$6Qp)|NAKMM&93kXjo&-36h^O)f>=IYkXz}Z5& zdRExPw{>!kB44}B7>Xb>xVV#`jKbBrp5q6joHa377oq|p8S?a5>mP7&VRh&yJg8VQ zQ8$vx6DCsHkHxa2-1Y+T7&Fl`?VCGwc9k1^CICS7*0>F(GaKF-Q38lt`W&T;sByrBGbyz*eISZ&!cfS+YQ&N zb2vtjQ~dgt9&}&b$Z@%p&GvyGo&|%G^O)I$V-9# z2Cry1DO*h7{nhx{ys#T3KDJ=5jYq9U6P{wJ*oQ}L-0!syk}jKo$8GcV?%*%(WM?mn zsm3xsF7AXSyE;~3 zDgNmj6cXs-02KuS9%nF5@Un4K3~W5myCLrI3qbu-Q^tG3bU%t8x82a$qHT`2womXz8mN-_AHStuYgU)+DAfX&M6_F$WAVcWqM&9q3 zWZUm&@O|4N1ql(m_D{J~pO~qa4xEydWD@MuSO58U?>*k2YHSS7ER86}82jrqKo7fLeq2c)kYj^tUCM|=hd9*dVe<~7(3ENRMd>(8FLWc-v=F)@O?T8KroAIJ zrNZ6$vDnb7H{OS7Bvxbh8=?qR4P8I)T&ub9uacbTFqGu0 z09fA0?xi~T8J%+h$>T!W;Z<1Fcptc~o3Cc+0oQkZuK!EB7_(wd6VLU?l)w82hXDu+ zuuY@wgBdNJB&M@ySNYwQ^)?l-ei!*fW9!36B^1xAk`PS*$&u+IuxSuTh3g*J%4Z{| z_5pH0idP!qce&xfKo|lw9x!w5D8J9U!*%5P(UgKpLa8IQ*2!wzigxPI8TR4`upEP4 z{UM;ZX7ba$VR!NCvbqD{JN@}?0c6@3{VeAzu_<9;5fT$qJr7Tj)r{QEHjm~;2*3$Q zglA#KGZqu~q%j%m8up$6w|Q~6_D|9Y__#*(#6;mJ3dl=VxhVDTsx!oF%BO%v=3>}% zw!qJq`;KWBtM=}e<4LxS+uOgN&oG*j_9kn}lES(*yfseeXS!1es+0umt~Q?yj@^}t ztcf&WV85>lbFR8z*^9+?MNf}i#T!r|KC@!27Fv0U8Yi2sRdGV!VAXlF?5?}_3-GqK zI1hEg1!&Is3{rQonD-NbANH#t#8n@Z(I)_jn@rSrB<5nZj#RvUGMg{chPBt+_gkiF zy~%7Q=j&tV$9Fr6>+%k=`keaazWreGVh}dMTc6XNpZ*{N{dAE&-}in{Qc5$3Rt>HK zCDi(v5G?j`GY>SAQVDtNZ90n?K4-OAHWO3d8+!F6iqVYsz{|$}eOhPy*5=WnY2^J(SMhf*XbwQ@m@t^i${z^z=fxDPXE-?#xa}>Z zQLVmg83o(TC2m&Iecrl{aW`ev|d_qk&>ytrq% zQQ-BW01jRb`m-Oh*P!Azv+c7@#aBS5&L0mW8+uSR$5ujvlH=iUwc`_tq`=aP(bSt< zc)ihsNcXgOwqQP;4Qlzm#OMd$ry#2Goin>$g9>l3hpD9PWX$E}Hhn<=3=9%*4mlhYBO0M&{ z$x*SiwCqNcGo^YkAK>@pg
=@74HVXnAX`wyAx;L(xfMBd5`QN0IoZ{MDXY*|yv~)A?c%DV zBpo(Js)B9u`Wt`A&KY;Hd}($jfvDE2Rrb<;e)aV95Uk$dl^*Ys1sM6zj3*#l+3PGV zDO*c+Xv+$1an@HU5m3Nm5FI|6hh^kev&H8(TKxKv;>a z*^F%b2El!t^9cW5>d_i-9cMmrFd|iic!df8)JREt<<0whPlwl#PPr4hFvH>2IBx_HY92_)e!U$pp&p!3KQU&5Gp;ddZ0ci5%?2`DF&i;9XRxL~!T7{rrEf8+etg@PMtj!L zbIV8`5BnMRr8WDtF6+-9`?a0FL*szePCohMFMjchmtA(*V~;=n&98sm93naP))CLX zu{WJ~{DEAace+5G@eR{&nwnp?88lF5pfM3dbV_RWv}u~hD32D2=%p)gUV%idhpsr| znnCS}8`%i`Mli5pYYSuuBL)iEM1W(H7~Yb3Eit!bUKsO)>Gpglti%@d%-wSfr)NYz zgRN~Z+bCx^uE0@rZ(x@h*S0nezYz=@RMR7@vNh{f%hPz#m7U zQDPeYfhIQ;*~jpxuFs3&%o-2SpeSu%8?joLMUV&v7z36c`^ zriQM%4a(^NiUEERqzDj37#*XK!^@Zekp%P!DyIVFXcOWm0f3-j1V2)yrsniloSrfn z#~LgXGuYke7Vpy%cvKBXx;HU;0E^UJAv;7l$7{H>?G-U2tW<<6pz?Uh@)~Z!-EU=7Ntsce?`zh0dvJP z;K)Xa^Cz7;rLAj8qC$1Pc5OICASw9f2%br?>0aewz6JoswIy95BB$h}>^K7CEMpai zIdX!eObTmc#_D33Am);7?E+RZVq2C29EpES%*b$5C~GGr%kfYLRg5yCkU&M?V6}*( zQ48GI*+DwL>?j;dVHuD_t(|?OLr`%sb!@tv6tE|Ck4njEQM{+Nj8iAyx?lN_J;N!Y z>cC2YK^a0`9qNsU1{(lNz;?+iFH-TQP8DxCe-jtvGD(g@B{3fe5)Uv9=G~~v#z>`; z8Me1fSuTmUmAF?{o@(*}%ZjaP59ooUJLnQiWXI5_e(+lv7R0;X($oZG3KP53tFZ;q z@u*5H!Ip*y`_9V(>3^Of!}vXHZ^W$(&HphifNgY>jTpmiOxZWW`q(-rzPHp6dnq2r zO!K1EulYD4(LC5}%-&ZzeIo`cHdFDj$07vk^hm}Z0v~vpTgGL~>vo^-_Q2PU|5fIh z{af}n)bsv#6+K7o`3cY0qqeCYJRi6sO$84y{JwagP>|eM4w0;sN|ZDH*-w5FF*EZ9 z8^N=eUFOXz&f8*~vEGEwg+NoyMB*6f=Lyv5VT(URz^8)$bHbzDcdI?Af5so$rzU_9 zs0|`x?ZYct+Z!>jFyjN~3$w>UbE|C&)7F%3^fcwv444;YVVH)Nh6q!xr|DM}#vf+P z-qVc@+7&SkfXxVM8@P=ryZ1KljRud|t*Wxc1&jRJ*EAM+$t9Qk(?9*wIp>`7_Q}yU zuA6vtBq}^lWu9B#HZy~cHKHHggL*;k2tBJx=@JQEBXu^bYsi2PQL9BAlc%B$>mY)z zV2ebefvl0dtWZA6E=$N8@Eq_HWt}OSEQ#0(bdvx^f!+j`l4{+h1fVAoc?<-)ga`(q zu%dIWiIf!l2%0;oJz1)f`AZ<1M6L-I#YKj*3P}^xPLNWGg9EKAQA=8Zn*^%`x@}BQA|hpIZ8KOHI30n0gw06_Zl z688)uC6nZs*M>j@1rwyKO{3yaT$|)^F-Z`|l*DopP}!#H0ug@LUrdm9T+X|bfpM9( ziW_A}><4x%H{2{8`q6OcS4oGAFAl=hU29DV5AMQj)k0Qt2=WMoWV9 zand6Sa1PUzXh9i4fLf3{><2hzA&3%XWP%DQ5yPSk$f=+@CQz7}GRsPhnkw4WQGl#0 zdYBO4DYj8aJ*T{l5L1*L!9V48x-piL))Y^NmnjgLh(d969P#{Gh`#D>a96Es3v`szZ#;h4<=-!IEry_O<}uCX1OtHIHHp{xQN}y-1si z1S@(|sjS!NrCvg#H@@F$UG)n2`r;AL0Auth(VCDoW@v5jn+fQtSpE-CW=&UmxsBlW z_Zo`U(~@j}HJUU5r+NnV+=h;1b(oo0w6SJ%?-<`$i&co8ho3<&oR+fNV@k)j;Pfle zEYU-cSJUJ!zN#=|j7fS@=C;YeZ)|t8v5ei6A<#a(CV^cx(d*P$8`iT4k!kv>(_?xB z#~I?a#P0OAV_E4HdJErc59?XmK8*H*%{j5j$6VHlDc{eeX=5}mp%tMa-=%|VOw24}a;rA((%+`iJcMrHPmzPR4AWWrUIgF2tTIv*o z@^lnj>S)W6kqRLyBppKks!|@sA%u;?{(!)UEfe%g5GXJ-iI`(|U<|Q@q;G>F)c~c2 z6Qzt)QvvVe0$oIwyl`Z}m?VBsNhNbGXaA8`3j*clIgU_rY$cUg<@o|I5hIBRF@_*W zP|>CUKmor9*<-8$bxWIOJpQW5+DTPe$_qugUKCt3=KjhFCkO+qQFqlb`+fGj!H%mt@LmIDN8 z>?$s`^r{v7L*$gB#<~$vi!?_0eQwDLHoIgA#FhF_2VQQiA({snSs*o*;vDg?e8-d5 zm0bLmeH(dchEa(~R$pgRDKmN;OG0fm!x(iwOJWfM^=M+CeTF}l_Aw@0t{1@X`>>s1 zbLuvnHgQb*Xp3GCkB2|d29L<7GPPI6Zdz-3bn04C2ZTX|CkoB1u#_WMCoFn*s*)@Y0HhF+bp z;j{+B6QSp=(8VLhTM?1O<~P}FeTd}eKmYml>({^Z(o0`ehe*apBt36ufp1%3Z?j?a zIQz>ZtH(vqwuiQ++r%*CHm65nY^xZ~a~t1)32+9Q6iAFv9%gVL9b&Zp zrVJ65M0a`Y@!Uqs2KH(Vfn7GyKadfm5$z3IPgvV%3uAz;M+qBFYl!C2N<5qp?a{;Z z+lHA@Y+}r|JyxKIab{CX45wd@86|?pFXg8M>e0kN^9;XVRWt=VwXEOZFr7B1ZU>rV zkNf(TEw|i!^YuUc;ge53Ieq%{zxu1c`p8EfSO@l@2V$Z5T_yIE=3;bXDO_)tUssF=uQjD?ej1s2PI-q8Ui(Bc`HDQUtI zPRIr73OJa`C#I1UOcW&ZPf&mbh)s$?SCbN-Sa4IFf4P&-iELW6=g`~z^S4N8ga zWX_R1alkR6BL%Dze#wO&s7ReUQiWe+qGu;{(acoxOg0auM0%O=uNdEsI7hN^-GBX-MeeG z+I81mx8EA+H;nZf8XEG`-%tu)8c#Xpl#xRuAA3xfwE>Yd?=$wB4DCRAgoU=a%@#;U zc_1A@GcN52p0F-6w9#X8hPCZ+ZJr1_jfl-}Mvvbxb~Y8!#@PJ!(1vb^9)3fQ9%$dP z`NwLi9|mo?U)H9W`9idH8oD8R{9r4g?dLZPy36SC(9NxBgz;;JFf%mWD52rkqkW=z zJoad2dkCW{P><=g8@5h8mQ4xcwCSz3ZT)^bq1X(qDu>?|H!iY?t?Gb0W+aBop*66W0?Xf&`5C4`eTmJ9WR})#kecLv|AuqrD^7GF>A5hZP@OEoY 
zY=x%oTG274Hw~)rjr8$vU0@5pfpGqKuo9WQX@|SXqqS(*okhONC>uX-v<#9bJC1a$ zw7(ij`16<`dW*s_p@la_&rZ^_LG~WuDCEmKPePS)O=Ff%Trd|4mF@e)=tL~->z~vC z;o5!>4tuamyrR^~;JtBV-q1gxr^^Ad>3b{g+9a%iiVn>hu=$!WCE``wn{l5rg|t#H z9bGU0x!Zmi3cEA8hV5s>le_<`eYI9TFr}NiV&&?c=_lgJ?5)HKX|fVSkKiy7JFHzB zv2>)+E9%vn6_Xb&lb0-Ti^R8K6Lu^k(GpAi=3Xg%t6 zCZQZ!tC*>5JqYBICB0oaFsV(33%*^_9gPw0JB_&UZ{7>|_fkr&bn)E*m?-W&i_(h2 zbNY!(7CLhP|DIm7T=T}2=+=BnPEgnKK2P5RK7sIOOLl0%Ic2n-eojy7V5 zXoOjj=Q@j=;C|-rFZWd%+30zrX+VvR)1B`ci+uc?(bV5GNlBwen=!vL1hBUJ`Z5VM(Jlea~U;Z`1 zO8Hj`%S7U3^Z9kZiF7AnjIWJWaD@BD&s?+IM zyH&i0l}Z#V0duKXELYaw6>+LnD#c z$}ghCf?p6Y`c`6EbjEzf=Pd2Il3`ya!mvnG0rqi}2!(CQes%p(CuoID~ z+4?f3RVZhy!x|a6tZZtdPNt;90*3sisZ?#r1hro0l;0xw)a7sBDXf&Kh!S*fkr8(% zd#TbltB2Np@r8L=`4v)`Oe|xyFxdbPV?|4OibPmAVRh0=GvpGPh6s5@BSZFMN!~ zA(v|Ubt87u2*5rje7cbEM)%^GQ1uue#sE3N%BoK$RS)RL9L5nryc+I(pEg?mJq&I| zvjMb1&!-%VMm4^WpIR@Xt$yPOzgi?4?2Eg&4>8E@wVOuu%DOo-gy-b<_>)KZVhWt` zF@V}v$|dT4!)U*86n0>Y8;BviE4qu$R%>N^)R#$B6ZUhtLOPQ65y5(UC++8qYCAMQ zYoQ9O&h)@?UNga%3pn>8@(*X^3f0i|xs6T@Nic^ok9lW=m{TL457LZ3s zePRSUk+NAzzQ4fofK6HuNSk+WJZbtQyqVBJa7e{Cv`W~v(i!Rb+*lPbyC!9V@uE-fv>qx3 zaJhLm1bhlPI4XsjYa<>!$P7$tEd$DAw0Q5Ahq&!YWqp<*jYQVx-bjK!>@f-Awj6bg zYc79#c<#C#ySxEMy0#La2OqUTDwgyO&lzYp;e_#UE%xMh#7^1JuLRx)IP(54k1F{3 z$~l5WK9IrX-wur{OTjB=$2iO0WnqoRFG4M;cQ&AQg(*<{JsuO6B=!eo`|0R+dj(aEJc zx{4>iXDdmIB(P)c>#4($Cp`$8zVfMe>rp5ZDduy~3yGf&B?Fb-nG9b3_R!R2n{S)c zt(I#L2Apw3yczUDB*Si;K1`3DN8z1^VdVEuX$@vMg(G2T>Di7y{lp8LF&i&J;m9HM zO6i74HgU;x`J*AngpL|ZdbaB$2wXLAlq$n}=2|?LE2@T|N>HN8W@xgmkVBCrTM~*vef;j|7X;`6= zd+KSE=1Q#w<2wm1e=!WHDUTQYR=A~rU{AE5r&ub$1PO0_Xc6zC-IYv;iCy@K3>Nce z6M~pF9iuzqh2i0Ty&e{45L%5d#^Uw^jdD~0bBVZ{?S`P|5dPM}+_@3*-i|j z!Uz>d$$%#nWpL!{3xI0@02Iax1J-uRSrpp0htFZ1g0`e03E;@KjS|y2Osd9y!?rYU zGLu{>c=*<&BCOEQHrY7qS`fdG9>%{|P`r@rEZu3n-jP$PRdIunki8$x8x8VX^^ZP+ z<|QOSACIf}6w%yO(pAcLhtDloMllB{RkV=&@?a&NtA97XIkGru9Kel|mMSH3vszEn zxX+ZTRop=_;;mu}+Qi&|JZ*c?w%3B{uX9TE&zF$qfG0l_r!2oqb5PW#`;j-cm(0bg zwGw{;4J*Y^!ab4?&kS6m=995(Xw&^Ll=K$MIjUhg+Mo<4yons4gnRi;%^e0eH@r_W z`53*W512;91So{m%{kMlZm0a?liIrbp)Yoj&c|Eimxhw=WG*r@xFi)5YyyWnvb*bF ziDyHVYLRvp_r+u;3|%D)F}m}>s_i$ZuqZ#I}}N!AXaa%Ap)TJ>UyP|{)oRYEr&fX?hJ_V&UNI)vx2K4 z>TvgNiVhsNTIrM2Cjj!H-Y(2((OIq?$_kKQty~mIq(tE_I?%_y7lp7t#RGk;mQdv^ z^7KiSrptfPsE!)PnZw7WsGG|n%OlgJA2&+LWhd1%U6jBE7KYJ%B3bwS_@p`<(^lK` z;w%!<)+L9M7Daj6!nDc#eSU&lYCz* zRUorSdg=nF!t)i&x`Xc)SS&I~m*tb8ELYAd5AOJtDs=k-;W^detB^<{X~M)jy*FFu zZvE;3%hZ4OWA@ysq;uGqpK-cKjox0xpl#@fZ?oDWnMCDiOj@FOn7-B^5=G$|S!)7H zgV`+8Kr6xjCvu=V{`j&qU!~{>KcA()xegXG<#-@Z`D;yzu|;$1-dilGEW_H-3q}e@ z0%l4@in!$f^D5*ja7=>gYNHI$A(In}_+qXCObR)$u!CbnSmb4s%{?FJHG+5JUoi`N zf}Tq0nDW8Z42weeF{v~{djoMx1ajF#5vU*Dl4q~mN!?aZ|nOpt&8R8&BXB{1ZUcrgFu_wHxE^Q5^^qnT(Cx{c;D za1sg$VNF?TsFxDp7o&iGU{xvM<@Kc0824OSLB}0RgT>|7e|rGz7?X-sBuYya@h=lC zLX5Y5G7dx%rNAGlc!f+Um%!p36Q*7ZLx!oczf#fJl-3cg(b4`Dk2TRfK}Pc%JLs(Y z#@N^20{Hg<)*XH(1n89i&B^v(CFk}i6}l==x1jTFOOeS&r8|Aw>a=d#A^|sg3d(+y z5&u}}j>r38YkQkJ$2)kx#r*!o3pfMkSUU4)$9a$2cRXY4IbtzC z+}VA$cMwe^uwQV%Xn6a=xvSIDN?Fsqh&0eRFutR9+nB@N$mmfi6<@8cx$!=Ba)Jft zWWR%sJ|#e-aD|EJH>>-Hfq1-v)h*yisdT_(yfQs?X>#1GJr$J09STDO&kv6x@q|(# zAGKNESv-#?!OH$YzF36(H|J)rOi!}bX#)9zoG$sDh+8>C@+p5RTB?xIq=r0n4Vo^b zVKSMkj;eWYj8FZ0yPR)F=wZn|Ic*xC4DobD(LsvLTe(Pb$u==+u+bts;+!%6X3QPM zgQ!|oh-Ljc)0}Cz^+}Q}<<>%np*-2YU{~qSSw_vu`Wp*03~SAem>c0}ause7AdwrD z8qkdDt$^F;xYHH|zfeVc#R=$6pJpqi{a9c-=1JsJrCJ4-0KhOwNyEWSbs42l-7MhQ#qmYwb0Q~GL$Xl zL8ez|F4`x-e?RHR#(A(hZyB35ItsPQQ$*n56rz_YmmE`8rH=L{noq9;_adCHPlrl> z!8T6g4uOBHklBfN0Wo7g3{X_`9rIZyQ;q6v=k22gxrVQgT&46pczBTTCkt8JIy5rX zjCt6h*6+r>-~&mBkXY|yzy+hP%8nVH7fO^I@_FeNXyZHtnD$R zx8ffwawo%&@x8TiKBBcQIVQlZkD;Lk29zmeSuWDa&aYi 
zc$f;LNH0#bG8NQrQ|U14k|nYtpU^uNtRpJkOr{Ni7zw^Ym;kSq07J@+a&BebDPPds zeqh}hSj!a?^kRh3okT3F5=ZFx5K~FzBfGxhfN@MFCgZa0ed>wqAz~!bK)1!GD8oCx zf?jDJFpp!eG+V!sO9F!S_PRnlz#>!G5WG2|wBI;BGjK&M$L6V5xM=E~Kp*!5YkPsU zdnz(M1I>MRuPG6KA8(_mlFXF z4#U?qOh7mQkV^D6?eIl(kaEZcc8j5XIw+dqMW`p^d0s1K2+oN2OR-EnctLI%6#ymB zMltHCm9wz==p`Pi4Em{Csr?5q*x+R7H7|Y-Rfvns{l`=)MBvd+-$okkLZzjIKkI%@ zi)4C-W|Vvyq#~8#UPx8V#R<}J%t2ZJMWiaZX;3|M9z?Vjq3`J}#XVX7X04DxH=$mU zLZcmBD#m@_PYC+@wBy$#jr{@0IjNRXM^7v1D1m6H@JhYRIHP(u~;g5BXI1QapXoAOZ-e z8nIcHrW+aLFCezW=25{HLAp;WwceOlI83!d$@hoHwOk29BZcwvTNgtQH*!b8awdxz zCo}2|7bXnzqqR~w`gAvY;4dapxMoD~3NblkxjH9SG%fB)6Yywx5J;{al;VkMp+s(l zMjhN5nir?Us`?|C^b23g^CjvOvXagKVL@G32S_w(c(gPwi~=C5=8NG6>jhsdb3i5o zj^k1#`ee5p&x+M@`#Vec#eVZ*BpKg0$Q=bBAso!07vLk+YBW0#p{!rYW-6y~KVnu*e4G?m$Pkp%%qBT*|(4KC4VG4gab zv*RwO(&!X!bX-F+bWJ}uMjN6cfdg>c)V?pZwqJ^;I1PZzp_>`Wx&j3sF}uS|uwGvf zYsmWoY4Vkhwsn~dl%?Rbdpn(o(NNo>v*fp}2r=3UkNi&bLdj3a?&egKER}J)?LM}yOsEx__%mbmWv?BY3xy0V{=*nk4 zibd*e%IPZ>`=~AzFYG>w1oo3q^3htXe^Nhx(`M+uf_>;~U-+{L^3pec!F=xB0gXXw z8`01Y(u^%lWa8!Zdm-!`haD(#9FzKmn|8!chb=1EckfEW@^R~4$Y56*Y)aUNvg$xH z0agyo7-p`H0SVn-iNH*f_-v}wuqPe5eJMDywePJttZ5|LmLjv$G!W1WPATW^qrh+d zjA<~W4Kb+%?RLsPBmX~Hs*VgYY zuOM?~bOd&k57xI1-0oa98=aVVckz5Gpa1>iXMgkMSK)X(lS;vck;)9S!ABY`aIPQR zy78kM*Hv;^B9r}cb?x5sRS--J4nE&Mc(&^dg(HPR5nuZy66vG0^<9tmUw-f&1hO=c|xIFI#Yu4 zM=Wf_qQ@JXpD(X`@A|bL-Mp?*$P2~d;~nSUe{l!3M*91Kow>r{wOwaAmA05n7skhc znf~Grp944pLH5?o>vJ=31PL(i@%jdGh8%XQ!LW_MA+b1`z&1>!Qu*1ll@FFK{_}U= z9x$6+p~$bkzUTJ)vHh6}mWn6tJbo&ch`)2=`piK8Q>XJ6pMHiSwNfdQz~R~$(g0F{ z-~HX!_m;PJlBsmDP=LuAw)@Ey%FmadUzwTy$=kQikB(5osKa3BS9k9PBjk3!?++N& z>bt9JiFl$^DHG+$WY2c@{^Gs2e|Y_xR@MId49GCPZO5I;%zp8dedKih?VYbzo%@MI zqEspa)xV~3?7KaM zQVEY=wL;MfN3s^*Y+^qif}2SCN@>PdRFO22i^DYt;BQwV6wJgXjs2{}y+q)Vi%d>D z*=;HV;7F<^e5rKCy-&Sl0>C{S>>Mc^3H%WumG8SZqq#((QUFi9)H-8IXVe!f5ZIA_ z01;aW?0b@A^*^AqBaV$lMZUJ-Vg1Or18@pVa^*aAm;gQi2RB1Ui;hVI=0?~P!FbWj zv&#rl0`TY4P$s~_K0jL93Lhr(DYy`2AP5mYqTC6)KNw$HunfTfHjqs|^X))m;1kc0 z9(hxdbTLhZgq7KT!WS>5ZVb+`Fy)Bh^QCLm8TDj}S?qa06`1QD#r+D2oYNBU=pMkA zBYG>4xYghRV*5bb2BjXzpiNb`8N^q~W&fT(O+ElPw6-Cw75veW-Kh7DYyA*dBplT= z01nXLRckxy0xtaRsii>;+|VG8Kz@HXc)lLmM+hSr4x^iJO}H2H3}`L>Y%E(OZyyDF zdDLK~a3nC$JMQ((&|$KWE^xa<@#Z5J%hu)U!{pfeV;3+b2;bK}w!)q}t~FOGkS=7X ziYQC;Vhg>9`@b`J#i2IRUWD_hPxqHMLkHh zlpnY!Hj?l7o<%bR4@Evg<%~NnnA8?b;Qhe!68G(a3Cvoj*C{(?38s93<^`vW{$&&3 z)2ZpUtnz`KtZSv5@*kHAJwQ{T5TIOASuuQ|8M~q$ItRNk*uIt1k#fR=#{-y@qAsb{ zBGuV|simCvs@VvcQYSsi{@GH@o%gK4T#jQ!8iav;-be1by%L$yItJ8`)n-(%2D#beqn9Bvu;nzF3SM<-DtP8W=@oV&zouq*f-62lP`cDyPHR zi5u#{1zZVW0dOs`eFw3CYWd^|T?G+sNMT-DiXH*%Bh7$v@w@l~7nym&oJ=9&B9nt7 zx$@phI-K`BC&WNbvvCSXBAq^bl69}ua?u*`JzWndgo_B+XbF-h{l;_Of?S})i*ClX zoH<&Cdbu6AL|_;y;#Cr2RMG@8Bbb_fmco(5%p~DU?cFOUF_d|_Q=SZ^L{bRvWB3fy zw+e6@p4-IHoLzVTTp?ylt%7+my%-UJ{W3{-(UYDO+-Mwe(4$F*wK^E`XB!u5g> zgRD;?EhU05)vIN=UWjEXN4*!XBN*z^i|CavB|vP6)!o#V3#`Nng?h|E;K*V*^yNye zh_z^sM6D`C((x5%C6zhGwrF%&Ag}L&rOJbcpHL6A8rCX1&T@u2()+P16C+A>CH>CumDZ&cznOTSBfSpsch5JBG+hg z$CcUl4gKI2U{F{u0@DnAMP3j9q!y!zVjxan#2i-j_{<GJ*VK5)-EPvmmi!ziS7JofjN3=3mi9+%&GpXJIh;<$Rb3z5_sGz90Hws)6|3uAQVR5FF|!}ajpH8M;E;(#6+NNv00&-YNb z7$l2!<$z5;KMpvO*Gw}>uN0h9hRp&ZWcx*Xrz))+_}tE(2A@t%w`IcGWeGl6+e}K( zwx~4|_*5ZZ1qelJNu+%;b$cUm?4Fiun?+b(!)$;bZ-g&cYZ`1v=1m5NIu$BVI1<3q zz@+}-dxPYTPyTNnK&-BPZ}`%O!%D3r57f`1xbwI8Zce^9`}fnnodlq+CiRO_X1pQv ztL1Xh*1d2KXQ-&R%4cp^O%C09qj_9w zPyW%3d>A&S6L_Ji6|yA?7{2y%MR-5EXl{6kA3r@$FP9m3#kLm!t63f#k z$m>C5?ZFuX)l4czUZ2^4i023(bN<}{gHe|dOP&oxY#@V& zD-o=Akp}2_;oW|~sam5PG9y?vo~}Stk&d|2lK1-|0Nzl^j%*H&)Y}#FH|=_>0!qQ& zY99A`GO(XGT&-h~umrgIBaI=SF1ptPB#-q)_OhGCtkZ7EaQgi 
ztrZ#1KP6sM18*3Pq|&_FfT0$D)RAuG)0)Gt(nS}rSIOk_=@hw=H=KL>USBeiq{tw0 zU@LcTee?hN^*;nT<3Z#Ao--Kq*<3yaL@JTUq|*S=rUv@2PEEp0K9$Q+R$Hagp2vgI zdeg9DFPB20&?k?c^qb8iHY)^PcJBR$Pd@Xw$?X#6czOf*+6mkZXq~@U{CzruoyGs7 z;d>^H1~!i$KVA8&-~2WlfuSY$l;W-pe4I?gX@w{)?H&-XB2N@ugQ%3i*(*XJq~3%IsG09*zRB60jOoyK?B zfBmfwi9fmjAtjgiD!$j?mU6nBV`5I@|=fC6s;b){_6LiJXu{M?KYM> zK*4gk4F1Jpv0SP6Z@&Ez5J?cDiBvj6nLJuu$2I{7fP5U_Un~|6j$HrvKt9pme{=TK z2YgCHe_Q7RX}9HuvgP*v;`2LSKY3Qa98S(uaqrC6; z|KiTqUp#tJgnqaKl<9POWn&X2Y+Ohj@8T*fTq> z?;kN(@4GfUW%8ZVn+(DFZ^GbUO(6bEp%LGSY?|f{$bd&8{14EgzP<{!@zSCvzP-;6 zmLsqXYV6X*kO5=LmkV4aN>%^GR_+oud2sxTl@fNTp#XE^zGw4`!{^C7t|_7%dXS+n z1za0NlzF~Vqza+6pcVj#v{#^#b2}p!N(rlVerR&6y%T)HY@yBSV7C0h-BXf`Ep9iZsXsIc-nD zyfvS-*o8@zSmU0+c_Ig1^-d2g4Vota+hIO9{VU95GF?i;+y%R#Yn=5iTEm7sKt^WW zfDC(Y080exa30$e&6g*f@OZ2H*+Lu^i^w0gOwA2k18RvETRzz z9Hub{hC$r)J}#CsbpE7EDeUPkRjT2nkH*76TXVzLC+rJqnXXtZXk?n@!><5hw%&K! z>XxarZRy!N(!F|XDZ1144W|&=Pwn2VB!j0wJz#RBJpREf1XBqnd+-&gpL#3uYx%r#Dp5D(6`=02$(& zpOs@CCUc(*Tatms0Rug*)@WBj~-BFJ%wB zo?jIz=_~=%8V9;o3dQ4v5{wXSH|CKITS(t0kAcCv6_+06v1(YRWRqhzg;&`|rNU zyU9uk2QW2>Rt4k{o7T}{!5gmS3osG2-dG?rwp^*fq=1-MmXpt&)oi1WkP&@}C8H14 z0NVjfgv=a2B0R{_8;5WMtQ~U@gb|knr#%orGo>I5t{ZDGcs|JFYq?@DkKh}kfsd|& zrWj9qe*INlLV%k}1@KlZzR_Yik%pYO_!Ksy+)O_DzHkuRJC8-iGfaRxdI4=I- zuOs{l#1@ev+P2&v4(@jYPDUO&|+?W`ROOBAkG3De%lc}HBslbR47KLJ&kGkLl7K!^k@K&IQ zlyfdse)_2gTe)r6dOWBaSuo7swo7H$$N+^NbN9dI4NyP>lnn=#fLg82Y%azw&QFXc(%Jj#n_iz61gS(~ z?cgw;$tqaPuI?XTUpC;#EjPXgKcJ8( ztoAO<%-mR*D^{y8qTJ+U2nGXp*4C~}jl&Nn69cFj1C(-GdAv#8U0Yw<+QK3Y@Z6o}tBVt37bnL5*MITP*IceopFDGOK&r!G z|LE!!0J7L3jcw5at~iA~<{zFdKYqSOggH4G8X9`<(nah;_k<%~udYKa(D~O-pF51E zA&W(S+TamNuv=~KU%LuS5Qe2sH#PzPVSVU4Jc1g#z>%$Dw9UcJ@n_pR6pNe~8-t-M zHouq4m0vx0`1Ru_uzv(By6g4DQfY%m^YzB|gO%r0r3Rz%=HmHtBO|a{{ruVT!Qm0Q zMw;mIIOUdB{3(ZY_nRFh?sPBnr(^z9gc_j1^bFgm7wzNNr0n7LdBKyjI~7S3vqpuc z69=l3wlais$@TC-BwNQKts2w3Z5;cr1L-*OQ`{cpwkky?RRj{-Q6KCg;mQT8COkC9 z&BODy5o}%Cinv^HU$IgKxOv~bVN)6}_fKH*x0R$*D|}6YhMop8v3)#XIV=(;tJjCl zFIdL_N5cH`2^>M^2r7!*&Cy&2a3r5I)TNRs&)cW)xW@i>*gql%1c^LWT#@ODLm`nZ z2lsa(E{a70e_XIlOqmDp8G9Lce*O$8|y$#3_OZ+k7^q-rs#jpytW zCWYFcjuW%QVg`Jk`?qJzLneG`e?4EQ+fk1@6{J|CL1lPr^!$*<3Si(X*Xn-M1<(^6 z+=~Y0EhB&Ky7v$lyj>oAB{{CmDsh+>iaUQk^L#+?$h&E(f*>9&+Cvr7}K zAvR>l;a~K~2OqdhJ2!k2;iNcvgw@if3UPd+LFkxt!(Ip2}|TMHaD0u7{>no_Epa$J5RSR@34{EK7nAgE4V zeuXZXjbw|dXeN*=VqY`4xFIbWos*taE)J*b93`Bb%0{g-S_SN0dsQOMn0;Z}zf!58 zl)E2Lq7J`IrEDNN5WeYbI({u&kE`Y6?mdP@N-d)b`vhX56@q&-vzm>i_a3OQ&sV0D zn};-$H`3d8D9Toh97dLZQ;Hnatsr~*^t12C3_yzeUs9f>G^BPu9I$4DOM9 zeBSVHA5Xnaa4wEBeOEdz<^lgH2hYPDBy2(}sW3OlCsi1Jrh+oVAoi+f_8&1UQlwC; z$FIWuVljdZyxV6PV$xhga=y$o4Qoe=MdC&-R%>;0A7JkU&|7R~ujYxN@kyne4{Q@S z5)MgszG7G;P!jd{b!=yb(Pj4NDKHe=1_@8vG5kfGrOhuYaUa)EuUKnV5wq65LU13z z9~4rhm>j@-RoDcvV0Z0ctOFc29pC7Aa^3eyY!~_o|g<@*W2{;l%mOcuEzg_{7NJd|y zu-;mfTC@dkG_mR|$B9j4a^oQU$UgX;n?M+oYx@+}6Dj4kx0iGaqd*BWhXMa@?-I6| zLVS78{_fJq&prZ77JaxyeI_Ebi}B(#d|!%WGKo?adA7rlKAG7t^rN?7%?AK6@?b5s zwg)~yGm{(pwwvdm!1|0rF7DcevSQF~x-?x$=VQy8_@G3F6`!DUMQ|`(ofpexDW@mP zIcfyHnfFBtk)&$C!XHcU=~GB8*XOa#xthwwo;wTP5Hib=MD+2td1*$fRTaX?yeEiS zz@`=b)}bH2%^CB?LZ3e?_#=d~CzN)207t&g5SLQf=o8pGlY6*WtFl}>2cMz1rm}lJ z-|z0>hK5hHg+MG9jHw5$g;*-|)iT8*$@c=Q(Q@;gNFjsuX#B|*&r+4&*EZ?13oMxJ zZFqe%S*+bNvS*j!%T{cw3Iq>OsBnvw(har%(Xa;@l7Lvi9~-D?L}@A~;#tGI4xwhw zD0U|@)*VIwhZAa}++eG_{edJpvjY2iC;{*(W^zb#I^Dh)!#9$^y;z_S7)7ehb;k~n zS+17I6_PCVVu_VywMTDNj-EFm0(B~s$M^1-H=7LR5tLl1ld5%6pqbNGtdp0l2*-Fv z3E|fSp}-Noraczfg-hcJl}(p3oT(#dv11wmvf0B6wxY22ga=O)!AKSWrJ%ozJ;~=Q zom<$TgLvQ+IOmUuNujqWDYl4O7(ujUu+>eNVSq_NPsR`lHT)~-2^NV%t`^Hxbvs;Q zCs1WyE@}Bc#b2tcwwGAj%(i`KOUW|3G(ow6veS^YZSGH*Z8gvaj(nl&+cM&3f&&{! 
zem+0lW$<2Dm(Qnrx)UpvO7qsz;9rcYLIKDR8=|KtC&0~OH2nDHjc;APgfBR7 z4Y_&b0vw6_LOvgcC&$wVXwJ_}{VzZHQNP)M@4Gs+>fikaK4XjcFq_KeutOa8##kyv zUP|$)nF|^-X0r~M&G1L5SF4uy_DPF6Jz_f`boG;6r}N+dIe3d$oSprjfBffgfC#gJ zQH$lT$`$9{UaeYn2ZK94zW^1LNu-MlbKk#lt>0|Y$Q3rd4t^w0OM2=I#xwKT*8Jeo1)WB{vb#sDI&lNlYHuyf--@Uwq9nG*Lg;f9H<4*yz=JVO7JG*b4pT99f3?A$6 zUU1B(*MINUjeqveTiE*Ti9}&_cIUxE5P(}WmBt?MPLXy_%~kH!+!Uyh0@jQhxZP|r z8}#_B_|etNKfQg^pjN}Q%pHwB+1MoU$9sFTg#!M;esKNBmEx>uVMO`CmCOIlw?42L z4f%2jUMBzg_n#hep>i6$+y8Byb;#*7;TbVY2$=3A{Me37Gwo`8Q8OE4YKO)MSd9Wz zu0+6_jsuSD(h9B4e~{6f-KgIKh>bJN0Iu}**y3f!lu9B`m-BL|?2&sT!)*r^h(u56 za~A#Dh#mzq5z2hlHvZxGujKc0#2#y#Ukt5NKjD8e(YNuQA>lnx8QC<9$LYkpN)?9Ep6G!xS4B3h0hX>i_^i07*naRO5qJjw?e8 z?~E+KcVW7e`_Jpu-N+%jkB^p#VlH1TgJ3`I2YyPcEDg+kFt+I6Of36Vrr$d6`;%nz z^(!C6S{b{lX^?m9VZS}JU{aESi{?{D$spiJa84GpWiCAETI4X1uitYfEF<3cjNx>#IuXWgl5iq}v6q7dIrbyAf~Bt`koR(u*hL`d^-i?-?p;W^P8 zJHRU~q9DR-3s5p#F_N$_9GbyuX%z;Gejp!)r6loi3S&vUO~W)2pXZe_wNkDgFpP~@ zXSwJ!r+NdkBzBG_s3923p$YzYwH;7K&)sk9XOQ2f9}K0uzW9Ewm?~FFuqr@|8kjn9 z*difg@@EjK<^)02NDhLAOT~4+grI>@r&_Ac=Ii#Gb?)u{bQxIzIlo7OB6j+Mo^k^~ zN2*J^aQj$E`w4!B*8t<7VfJ0@zXpV)v`oMdvX;-{7FW&0@*$^c@Enjt&G03JHz^Vb zk~r8b26k}?(dU?I_=0+19yEB@k?I}Tj+^Uy2ABq%VI`;TpC()GmmNpiSZ`VwNkne+M_|kf$p)^a8CoUOwbm*oSgrz%rpdSRACtn z-TRMlWr&LS@>fcSWrl%5_^_JA4iYj`VDx9W@O_1O7*>-C^9X2){#{yxQXFk1N2Wkv za&>dVfhw@;6e;QbzgIJPQ$2W&GRho1gTP!G{wT>-P&n(v-f)tGqqk2t{{ht7Cs885 zPp(dHd`?U1zhrN3s2JMePJ0nd)O6u{7)tb4VH_qVlLV$bsTB*@DsEi(s8z^Q+{nyt zyhJt%SZ9fMQ5ZyWz33Iou|HjEaDYk%TiVUfkKs8M2X@KiqVkhbt`{!K1L%)_?tNAN zT%SZyv!q0L0Bzr)MS@?}783)B7s@PPA~=1`>Wmw}GSkxR(D&aGs};FGtQf+kW-?yT ziNv5C{n-afn-O>)oB#wqB}+KG(B=Zfa|KKzBXuki`}$=A-@j?Sx`3U|)qS*kYMLu@2!zV182~>zt&N-Br%p_7$|wB@D9+o^=!zIH07Z+s7T%P!{aPSLMjO_ zQWlA)CE`nRB2$RxZ#ZBhx&0u3DB*C|2yhKM-0=^e;3T^u_hvHg&tLnqQMH~dVqbX3 z@io|_0NWQ*f&g7>Jn;xQdG7Z1Vk0(I#I!m##z^H&B$Obkc@5fNe`OoO}dFQjT6FDmNU z^My=1pSNvJK;BkSl%`cALEW>To`uy}XwKf^|34|o)x|~>2^>_dEYVgkC`Ep!0qB&^ zSTW{Eury_eQ%3d;CGaLTKRNc_fA8B<0|U5`p%v_s{)@M60fK}s@ahiRJ$PdbCR0zH z``Btt#2_moJk!bB1^>|M4cF#o=SK$7WrMm0+mOv>Gc}MbRVuk+$!0Jp<@L`v0A*nJ z_rGT{#qG#CxH7tdXQn?(`1P24&+3z;`v zZlqzpH*B>c!B^Ax*vP;6&WG1$W+`k(^iS8*9h7M_84=tDC?s{`czJSS_rk))-X69~ zLx5mBp3LPOGQ(+;dMxC?xZ|zL9f^8FxYd(mZfx|Ui;HF*2@a!>N#8xc@caP2A_lh+_akr(N=NkeMcX(m8DrTvRz#spJXOe<6j~}F5r=p4q1?{+ zJE*h#VEQ^@QiEP8mSS|=sZcnVUJo6{8Wx>cgmbTkr&#J#x>?f@AYC*LSa3*ZamD?G zW3nx)2y@xvYO-8KtU1(04glYI>*$bc#8WOPGH%Js4{fX zrj;qHCvfQL4`mXofdgu%2_Br}WMaf}Nk~PxUUE!Mn+N0^?f~p{);bc+C9}o)#K6Q3 zXOpR7U2@#S3e}1;7NB)8F>uiILBf|V<|rJAoYt{m{vyND3)STUa0$LYo8TiX2!$AfO6dNzwr&v0~sOC5|;leR@a?Ln`|e4ZI^( zn{`vSiUDVRCJ~n8ZovI*{*ciT=8>kEHef(X0Z`PDs9_`+tmIQ==rbFk5+WbBZKUDq zv9Z@1v56EK1n1!{4VV@pP*28r7jAKy4fmWB5^m0oBgw;1R&p`?Biwq9o8yAryFic^ zK5CVm$BHCLM@}EK3KB(u-P^Y5*vz zdT5b`DHkhEn&I=A{Rh}e->A~v1rsCi{XU~ZUR*kxX&MzWy;^0$#j;1}bSS-IVa zJf*{=9JEq65*G(7rF9d77;Ujqu3wy_a3oX-t44@XEtaWEN>qf}04F4!b*Ro7^bSBz=`Wx+c+51;kD`_sM(pRt85}86 zDb3gBWZc$i&5%WIGf_Aa+mpPcZdJPtz2&BUo@9gRNpkonkGmicIBTg;7BAFQd z${Ypqp=4@f9}qBQNc4|49DJCH!C5ECM3f=rpoK*H;oMxp_S0~2D*9-nxmBO5`6^TE z?0w)UF-oikOY_@hnT)J80UCalR)W?X%C@!E%jUOora9QAV(M2Q(g;q0lp%TvTodXo zWVl(Q0Gon`se@;AbRM`Bd>z8gAD<`jnZb*!jd;pHo$3TU)$yy)G+H5jY0w}HA~rjN z8}}wbny6w6OsRfbLpW@>MB+2SD{ZCHXb7_!-=7Q3Q%2bwJeI}lATO_7lDn2sz;(FM zYh2>Ih;VJH_4YuG7vAX~oY5r%d2CXS_%eaLbm%Ap@E52YTr_Sy3=K}}5cj4{^_DD= z9AA;@MvH7{_*76D)f+P5ezDBCx!@B>VL+i0!+|3*5mX2NSSjR*2C@i#jT~s$=uoQ- zX97v-zyUtS-g871L%c6E4bDJ?86t@Rg}ZjzmRu!;LrEwRL(W*UAQCNVUSVCDhaX(6_R8!G_8n6!60~Z$Y}9VWj_LcG+X#UjPo)vA z1lEqRWD33`yD-zJRfwR?guyOm7)a7Z7T(5T>ieYnH(mEl#=9#MROy= 
zOH)%)NkHFd3qY(f-l}sV9HScCfI00itU~5VYsk92EOfIWt2^}yyr;t`A}FekAP_dvR1CYuWRTalX|O8%1lRsavXG&1 zv{9sn)j!z?T)(SeK;VQ%CszTEL}{v+$#705`SS47eCFNLdbtw7+-}Uj9dpOB$wHMF zron}Es*vV8d~9W+ld08G;+cgW8krIU3M2@paIT6^4jEkc)%pS*F|2h`W%Du-i497P z3^9?Lubr(eydVdD8{xHPGweP{1`rnrZQ&oHPNz#5wrHR%5KtRthIF1`t$}dc6pKVD zErUURr~+efO{ zK}CnkfZ*)wp`*iO2u7x)7sSs@F`dr{uGU+?fb8QClJi)u2Jiy|y|GFLZE)v3&O{jiS6uh{WEjb#G-2-*sn7WW+5oqQ?7 zx7oyA@?4Q@0>IsXT~@5)RR2^++Ab~P=L?A0DNC?HWI5dv0-dl^Itj_GyKz(y-iiRR zBnP{dYk6{;!hx*g5{ds#PpgF_Mj3MxIWT_IPTx+feZpWhbXyOzO0@-v4*en(+ZIv& zFZ@Ma)(a3zG*;LNX1(i>8809B2ltCL1UM*U8mM@kC2f?R;PJu5Zf+W7h~!!XnV|{+ ze1;|vDVqr+GxXCDrA%Zx2N9yFJVq!?)EjL_iICHkbSUH5s3Hgqd2@mX<04dzUF6^b z3CRyRPAZwE5xPjBYeI^|(oHFn>#F6ZsGu*!{ai)zQ!8fSbdsI2c<7|4@nMgZzH9rE|dYb)e!cL+lRsRf^dZ%XuKV+2?T6GFY0^Z(91s33+-`*-;25| zM9tjX!!JwM%TE8HWS}{n-8}13)FkSZYH%fgSsJl}ID12q*VijmH@pUk`>-3DTuQi$ z13^{$cMnJR#CdJdJ{%J#7%cx(63weY)VC9${w?s2UrDb1pw1W`a>B((rzm zkxKyZNi-^%NefWYc5R_gLRK0ad&Z2Hz-Cyx2{f_idwDm2@4|5r-XY7CY!09&k4~6G z6}j0!zp|h&2%pU;P^2*H>!%!8*(kytLOA1L*r*w5f*n9<5o(;mk?2=3o@cxnjL<*=hx#LYEG*BW1 zXt;(&K)z6czfNu+2BGF=5xjPm8|I?DSnHVd6pJL~u;+(hMy&vjJY z5&vB4i3h(vy!E4Tt*L&aK&E>u_V~9h8jcwZO#>Z)wzQdHEtk~6Goa*Ulwi_X=yrM3VZ6`ycMtzG%AQOUBiG5uG$fUg~4!n z60w@BrMny)YBjOoX%3eNd}Kn_kVR1*e5ro!CBpnKOYmok${PzH;6zdlHf)IF_iGi zuaN@L23Z1uBO?VJ;@)t>3Wt^PKz&eDMj3+n;AYMqQq9i1C`FqF6=lm$&@?097crUn znMm^wRAfB)`Q6jqYw_C!A3p!Ka?-X~5ud)h7=J2#gBqv<2XBz$3t9aaA8G6e#>QKG zpk<;K%at#mul@C>pLv5p#HC?jwXt1SE>~cSF{2bDr1*X)ZGOi9a={<~Zp}cPlwX{f z__sg)^WQ&Pd9uC{iAFKw;c*g2C>-7w9c}LI9|c4I>ZA8zvWM6)WNig}P%2d^*m{m! zONhvaf6dmLXk?FEhW#;hr&CR1iLPpVuMJ5el|G3AS>61h0;>LeFi?L zm0(-OIhDtYE7Zb398>_@DU;NfbR@V(sTl_ueFFik+hn{opBP#r#wnx-H!ALJv3J-gBdj(;@o*fW1oD->PsB57;0iDTT+>iH_qW~%Y(qJ7QNgp|j3S(s zjDM3w`XK3tV@FyUVq+(=mpJ?)YRNaExZo5)pAjvZ!jYX?0yo-mMu47>eaa2}0!Ebd zHeXgb8~^4ht0hPv0ttiXTC2z<60BZm6@Z5LP$#H<_EWxQGUeiDhbyary?8!}9og(n z0KPSWdEuLFC$$g>b$2MH7M3Y?Fv`0=WFy z(JFi?VplbdF%52L1``pIkRfUu%@lVlX=K;`;`-uAPBMW80=lQgc&DKDsKA3&2d z?hUQZZ;?m!f}|%E`SfskFXn~uEY`Jj3e|hT*TypP}hvbKgmQi z>B|&S6pKXc>BW(oW_>>ZN4S7Iifjr*_QejuX+IsE5z7=dZkNo3LMazc`*wZL%2l#h z%$HK7Qt?E=ANC%*eq?Uw%C`4uDCq(4SL-R{&XjI=%bOKfh6Kz zDU!$?^<@tOMPF?6FW)13nL8Y?Sj@VEuD`!i4956jwIL~@xSS4vbD9MxWK9TCum->q zzd=o+7&X$Zf58Ptpgf%bzSPuL^}rX<7idpuHDWh(?D1A&%~_5Sa|9>{p?H!gt&1L2^20M#*X;FXHVm6d*x`7Dw7bP4mM|=UUiUs#*+Zcr#sk( zIm^0FXgeKDuucT8me0D`%AXaz8t2!k^~z1d>EYx>^T9>)?qjm5&m@XZf9FmFbK@5+ z28$dPj=_V>!AiuvL*O$UYJ&_;4+o~n_6V$GWy<<2<$$)>+;uw;E`W(UkHVQ`H?!+P zwh}r@Cqe}{VPu#l*FsNVKeO@&-&;SCp)L5bgkXD5>X)j!-7^;@X!!ppCdc&0MqU zttz8kiGT2u&W{BRVFxi!EE@(A0uz5FNva1j*8}Yk0?}fJRhY;}-Y;B{63d!Bm z8n;$trozSWQ3_kdCGykPV?!*nLunpV0ggn*;d*@IeqiaNVMG$oCd=zz`HxoX-&xIe zg_(>0)2R0nf9WJ|z3yIu#TzAk<%{l(O7SOVZ&Cx~Vadi72{zl1VELUId?5iUQih!< z+vf8#`C=fyJGV}93i5$NA;Jtl6VYb=^PR8OwzjbV0t7iXJ&i5c_%I2m@&4BKKYsCL zQ%C%2oa*qz=H;DUTXN5Jol1#4zi<|LboH_~7IOtdPLB`qxM6n*WB?(vzkIqpY_r1n zkpY6(Hj1Epj4i$$`y$bUU=TUjp)zPPLVnh&g!t^PcZaV!R$Z9r?FDoQ+fROencx?7 z#}+^jVOP#CePJ3VW(v8CKzsPeN1_Rr0dsMH2@w-$FJmr~$qdG^V@2YD_sxWjJ0wm^mhpV+rO-rwJU7>;m{Tdo%DTiGbU zHa0?>NDk3<;M_Bj$skNHIR)`4d1+oVi_Ke>piBh6Oi1L@xlI!^Ka@}UQc=FSAFST+ zPrwAHunP1cENP{rV)hpt5r8L59;#co4I$xlc?bG{>4Y^iKuKhS0LCf+>?jFL^RRiD zN8Ub(SS8FH61%5iW5}n2M?J;_ehe9NKFt}^d|f>c>^=5wMLCz6SWk_b2gZ&4W(7<; zd%rq*?nZFc4hNX?)@B7Kw_ z;8%7^g~t$o<=J}X-J)?_U`#q?8Nz08y<8Q{CLVdWsMkwc9h=Z$6iL&tK1<^pL9NP6 z3nOq|-LQqEG%AS#?>jVY(^_V&qgq+>lLmP&ZcNX9je(u=wMy%}Lvo-VZt&9r$0(ZID*wJ#HS?A}D!9997n5|}SNg?ci-12 z^%_OPNC$L{qmu^s@>5-93-D-MM+fAPUuRd#h<$o87X`%3#R+^Ya`2Jd zF)h{G+X7YC3Pqz%qlH=_ojF1f68??xB8RNJ0BDn1u<1P;*tw#kB^f(UsUV9cY8$&mE1lj(V9e3bw03F@bW?YUNO%j 
zL9GD)dQA)whMdR9X&hu&u|Zg_uHQ%C#1L0t3*gN%UM#dezAbUi}ft4l9C;7d=;}~=A)Ob+R=1PE` zcM=ArU5e_wUi61esuvM9%gOF&Zsy>(iMJBdc4Tl*47L($R6;&tzYuz^2p$;wzk9Pz z+@Pud#is9ZVwMZnkoKGBHDmpyXc`86d2bZhA`JzDMGY({pIkK-l_Ciy=1L zR?;-V`J^NtpYHWjg9?+@xHMggr;CAj&Ku0RgItVq6164Ol)K4Xx9Zrg)+9IC$y6r* ztE|38g(kP|;qH(doHdLungsF)#!+XG!3ePqN4hhg{Ppg$Prb194!ICHNBEkC=sFWi z7B_-id$Wa1dG$+QJeX5xq!{ABA=ZrK&H(q6gVTBhqCrC#FBWMhaG3t&-|cFQvc>lY z?4!+RG;Toj((IqsPh2$9t>|d_-Y<__8wrdNCWF_RqNcT4qhjKcmD#-tx^~9XMuTn0 z`9s8d>~`jXdFBdS!K@Y@5%|C@2 z1Gg9EDSud}ABMudKma{KRX}ojwg;!d3OHkxGh%P6YHT6LUS_pgl}6kiB7ZWNNDP?k zcXNL<>WkJ*{IpuFRj+Ti2q!JonRgJoAT^{Q83)OK##FgxXUyCNHscIsdB+OHfpfd2~JM-6G3dN$Gu7_A|u}E0^~Dno5n1S7c!o7m|~H@ zfk|G(_LS8$wJmL^6JHjzLkydw-rfSyHD z-dDAmcAB2)F_I!i3Ypp19c^sHDinItgI@Hq&w5ZOLJ`uAqzSEdc1PQ@J3BM&r>ncG ztIKq;ibb;M9YB!q-jn+7e-a-KuFF3YAbYl(S&Xc6&;8DKzI*>a6PY)!8NN8s)Q&aU z_&zo~H-35S&eHZS(y?)yFB!l>uQL6binEnei_|z)!AuxT#@;bjc=7Is&6(Dv^`(iq zk-26377d7@#7q}_WN&WhL>XH&!gFWr3ci-MHJaDy_}s|ys(Ez)jOp*#HC|sAIL)r| z73R#G{`Mo-s!|xnX6|B+bzDo{c=!>f|B4K;)g;G=8+IJu zeDUwrW=HXP7R1T-bn{-kj<=(Wf7f>A8?r5y79`^GvBPvwPWQqI?bF8q#qS~YV(CZLlA?L zzHtOi?%<0C^625rwfC?!mxiOEHO*a%cc2{C81KK8`2)`rfZ@yekqc%fW5pgEAhY@F zQlVZtmf(;f7Xlv<7ssxvvB=KWXTDOE#a4#(7@?lM^rLOz7;(zACpRA4k2Y2pX3f#w z&nAVn+3O!*l^bA6;K*}M{-^`+(ct9PXe_kT@OBSt4D&tG4@Z$2Aifp+WZp^B3S z(M%k-J^ksGv7ec52M<@)appBP(Byy_zWPdCOAZrD z=kGqiaAYCGyn|zB$(ovG^Jxcw(f-o)gE8!aafqPU9V2tT}xvwf1v!V@)wZlsW3 z;@K9>>9xH)UZ*#l&Z2^eEW;2z=^KRVrUy>^ zJn~Z4=*>Bd{b7Xn{5zw!KAXh%Cv0ObQwVXA_28M-u7gdcCo@%z$uTupdk#G_V_59G zhz%sV2i3rzPiXu8T+g9AnuE&K#r5IOG4ImYxHPx~*2pZz77v|mpB!0u|F3T0V$F^( zUz&P=fe3ts!zDu$W^NvNrSs^^opN|L_P%WDs2ja*K8#`b^6sTo_;D4w4&>bT+H7OZ z@z=Vr0Q}<5Mln5e{LUQa*y2)RM^PD=#x7IGUhO>ga;LTXC&(Y`qv$I%0^my(3d?w9 zS?$uV%1}~fv0|SVpG*(f!;zuuWT}b6P%C#5VKgkBc#*dw?<7sDZQ(Jm(&Vc}oR+o1 zfQb3!<@x!-uNl}e@Z|o1#!MYXpl;nCJAeHq@x={h-!!mIA?8>9)pPaw&x9@#$ zci4QCsT~@`Mw7*rL=m$yvp@d$+{pt2&kqga7u*l8UH{QXA1@{6W!Kc!Vyj1tElTH* zW^#Jw^Si@G_U+4}XRzYQdEgZCp3ctMSu=q3$>*2<=EGn9;Tx~z>ob#cbASEeFF(3) zQ6jqA+xK^E`};?x@@^6dkHpR_EiIsF$=_t`*NK5e=%}-bk-~( zi9a$`H7)fy>|=;ooEWNm|L&#EY|HWXKFo=`K6URLHbmJxRk=IY+EiZ{>IEm78R?%7 zhp&t-OmBXd+mw*>{F#n}*pX1iVNubOYwd69z_dXSULL=T<6kz><@_Jb@6C7Ip1y}; zf^a{4a_{ou+EP!x4O=Z;9lwK9gSJTV5#rIKddwu=+$$5!|9tWej*aff8YI}&p)c3b)6k0RiHWG!aB9%zxUIPp%@|O{(f0o9lOvN0<`%s@HT)MB ze{#SaAYO56a`e{Jh_xoYE7wwAok5R5Ha6?LIyt;A--az^us7(ZW4GS9^~ux{X6-#1 zYU$k?rc4?WzYtsP-L&Rc@L8?!4uH|pfqYkYLklXgZ2!u{$kNszsc8lLCpz}up1wbm z+{xD`?*6OKe!Rb>8%b|Y-W$P&eVcRiHtCE1AF%?Wa`DLRq-}x0Ph@d+b5#w$u(^1t z?QH6=t3HQ;IRJ1vE6%hTuB=!ar}CmzF{mj2bX1&lEvu)Dec^nusU-T3L!>O9{3@FzfGvCtc1 z-#9BA9{AB|_G>oj&roMl5&+V6JZrF3BtE5~soaoE=b`2MS8?J9s#eCXtmBB7wgD_h z+*qEQzVhzM#2tJiM^)vKw{Sz?ZFm08r^}!xMC5yAI=cXb3RLR#zv_f3LQ+cW-6&8Wz8yqvUvQ6}@>( zcH-K2ID6y${NPK-x_IyNm1%p5`fBn%R*2raH`6&Nb5WDyr*9}$Ge(XhYpU`c>&ea} zAe_7Ub8Jh7$(!q#?_tI+3(c1(yjbd6_fJ>sHTUYHiuL)4#Ru2q*lwKVJ$v()i^HFp z39UGD8;A4w9z}1n!u^Sb;K{LU;K2P*S=)gAF?^36p}p9%(wNDA#+l5HtZHmqPkvLx z&k=K1-d&x(Z%!Y=(KlmgGFq|q*F1^XRuqq=E%ZtefBc5u-fKo6F9O|foJJ31d)jcj%5WsA(9%7dtHSloz&k-WL(e!e8+$}bEmCAUS{wR0 zP{bcA3<0m|*^c-Dtg5oEJ)i0A#6(owpIF>|=i5J-{rpa4rh4)2gPBXWaI1S0X<-@= z@&gm8A_LoKfvqC(DUJBc1SCZPkwjK(QK{H)5QnZbcVy20#A; zgH6X?={ovKCq~&YGFXsTv8VS)%j3&@#yGp zeC3<_G3D^u$L1#lc{FupPrlK$vapVwE;oIJ|6^71>9>2aaQN<}X$+QGwP}m^VRm!Y zO|d_rd$6wg*h~1D`RJ1$3}gN#j;~o>DExwjLA|zpji=x4IrVl=L%s%Y*L$lfhR(K6 ze0|}wcSbP~Njw61c^nTu=+6wYps)IcM)NT_-g&u6oB_u#lxu(^&ut?sxHl5F9(6vwP^gUn? 
zM}{KH=<)QIGn7;qB()YGwm)j?4a8sX^1sNrc2Iqo|EPW2)n4SFZiDKl^_N_x0dC9{iXIXhJF~2KxG5I&v6W{-K0tbc;|; zO-=pz_kZ}4PdrIGxui4WH=J%f!v-O$1q}RrlGzzJG*E; zfZm%M`^$?z`SHz9Fbakfu6dpb84k7f>}%{8O-{?f$*IzA*aIXNwqqXUPj8=}#fcdhYMdUs@ZiSk<`Fs@$uckC zPj>A8c=Q@p;)4f6mRHAbv{D=5^L&p4)e5d~q5pzHzq9ySLBZm>SL3 zW|!7i9xP0a&rjlB!j1H5{|Ox6BO!b)>88zMw&UjdYqa7=H$EO)m|B1I=+^XuOJg^9 zh&JA2a0!qChl&4kl1U!3l)}TpXhu3{GDISE>2*W1_OCAb+IYinX7NQH+c)gjsVXtPT_2yY-RP)eP=Ni zg_EqY{0=h;vF4iT>T7d@T_^BQc4JI#EMzhDxUU4(kjWq2_uPlqf3%XU_(u4}yYJ<* z&Dq**Kj2Gmkl8ZKhxkobKGX6{|BIjA{>8%b6kgd-D}tM|UAb)2-SO)i$spY{ZkvZ2 zbyzx7RePZGI0j57X75igjLk34*4NhO8#=mjedq7|Jb5XY*CtLo*gn8wUW6$-pD=w* zH}%^8PeV&WJ9x8P-#Ij&?^>O>i?0zl3t;lX50b-PDsU{!>f~LVQi692?6+9k)QRz8 z?ASPS9rN-qz+vv{+KvOwhhLjmU0OAxme}!bb>_;?YMb}fG-7ZN$1i8)I+|42lI7wL z_ts=GT|>G3&vE|XrV9K27@%1kK5u)Ix4z1{dYo{L5zj|^@a4zInEdPqnf3$N&~OzK zmN9E`^BoWUW4g#?F+9O5fI-x`pP^`p;=$pdO2XrKVYy7Iyd<8 zA$aO$+bOy40A2bR*ASFXXdeq zWbRNerYO!`y=#WmHXqo=annrof78POpml4ZnbQ|;T=_r$x~3`n2ut%vCvhf^2{ogx zdCY}`yxNW?nYXyUxQyd{Rwrh0TKD?g+{DLMBneh?1B0b8a`~(j0U_Jm_^ec!%n5}T z8&*zzdbJWOKFT7>pNbZ+^XYkuX$FT|bnZ zw1WY`_I^yBFc08|p2Y@^naTUu@qP_&@R+!V6K15=ecpef_4oe!(-^Ho73TeQ>^IYO z_Z@2b?f>S)(t^2NF?6eY1Nd+KccIR&3+g+la;5fZ;z}g7)5qH-3NU=nEb5GixTVIhZ?xZ5&&=>jzG>AY;;;ZJ)K^ zL@WA^wto4)tU2*|HCKz*(Tos2+yA~^g)2Zog~3f$3WCimd3!?`*wU70gq*w=d~ zj}vCh?9}=7>YB>>=9<>ddQ5ahm+}RKb@$i|6jyWl!6GKH;*1)s*57w1-`r9E#(z9` z^35IyKsqe;hk*xbyAI|X%&gBg9Az^$Ns|5 zWT#sFjZ1v?%YD5^ngGLO(ysmfYZSSk3kd@MLm+BOdI6RM_NEmENv#El9SQMPg`8m9 zxzrH(Z-_tbaHMnxdt6^sZ%?@whOu6J#0e@Zo$HB&N2-JeumAm>oyP|ce)y^R4uG8m z|NMJDY{+JpR+i>x=6E6m&YhTF!6<;|L%F1In+$(bw&9JFCqBM%_2T6#m>GZW!i5{d z!*zAF3k!>wqG&#sR9BxqeCV~~NAbJX$^QP|fAh7;nVI2{5e({Gx_b5UjqCW*!J&tt zg9m@-)tB4yxe3Iohc~sV_m~iK<=VBIw{GF>7-I^8Xl`lw;&W%u3?0NDHuO1mSM10a z-h#0Q;JwjN{A+J+{=%t~E@N{;!&jbr_RiSYPu~Ab;Fy>QC8nGm}98GQxs z-pQQTXq2``*4XrX9iQmQVPLJXr z@K{%iX_GR($%HsL`uTlFFzN5k)IBg@GUm+UthxGT`2f1Uu;`n3$AH+1fu_z^`;O18 zERM~Y!N=k0QH*bbje||tUht*9;~4r(jxKRHE(G$QCwbKUdB1Co3rgO8is`mo0oSRlXVN}I#l)M+P*sU{NL_*gg(KifJBLz z_R`weg8dmM7N%sR(%hKXv#hf6a9b~q5}H~xvof*o9{vd-W{hG4N=}*vEu5Hk+MR2| z<-kl*sJlFN6Jx&cW0=^i^4`3qu<7Kw+4J*JzPA3WgU{g_0ErPsvkhQwgIiU zKQ|6Z=%0)Cuiu%ukGDnaYRbcTz_y5AuJi)(Ys$zkrb*zvBW z5hN{*?X^|)(hv+g4zwMSrsW!1sw->6?rR&WtFGtvuV6;_Hr%2;!skn@)E?+Kiv1sFv03Em-1^2Uh8m&2A=8ZIk_XyPNRyETNx_CaE*H$1+W3II zcChP2O;ue@b#11m9tY|!Ut8Lo4pQbgbsa7$vv7)al|q({$)1jDjy*$B3l)sDn9@ z5}Vm~eEI%mOPD!ky8FoD-OmKGu`s)|Ik1TBxEL$7)s5|y*&HyYxbblD&bgJbYY*|Z zwz6hPM>zlhKmbWZK~(iu7&SPh+5jr{WP6Ta;v_a?l%#mZKm19;nnP4$={M{IX`V5K zv+2ldt8=m*{L%XC`1Dw+{t*o zWCwb>zw{i2-ms*)>x<9b`j_9sdS{F$PJVQG{>CU~_Tfz&HYJd4;(}{xEs6juDv)UQ!VL^V;9BD~DtIi@?tJS^T}vLG+)!Ws_#1WqosXadA_aySoelln*ibT=IGNCI&KH9)5yPIL0i4uY;;kapMhqRI zZ5`tJwGt7d@uk|*%u=I(w`7okf-w@*+Fg%-7zmfYu7NIGCe(oe|5EW&m242oNN5)N9S<3Hr~Kz<`*VrXF-GSqKZm2oSm>~ zU=r8q!Gr(kx4-dk-hJ=frO&4)C&%UBlq9~XrR9xh&wlIGmk)Gx!GWoHUpVy)&dC0& zUwrVvr=Kk>%wxYsNz~Nb^6HsqzWLINXAd5Lc*y3ahsNGar>}?vnrFOi$-hk(j&nJ6 zc(88M z$>%<~egm6|V&pelSBHH+pB)y{9lg0$ z^p@39oH(>E-`bdIl$Ed#=Et#L;i1-(Z|y&Wp(4yKBVWtlTo?as+!Z? 
zM<A@^V0`fvH3RO=Z6xU3y*1u*thZe6!z|%oLHKvt*pl0 zl!MJZ*h&%`fJ~h#;1?fU!?j-BSer~1$Tw%3PV^tVaR0g#F(AH( z`JQMt-Wl5V)m7JAn?5))Yg&PqXM47JAm2H=vcNHS+(m_N=3xzPjinhyc~5YWC3(ox zlx@R5Z)|l{4Nkaj%XO`oL2&GsT7yA4d0TGIZhk%9jD#u-1!n7Vc)!Ki-)eGz4>pU$ ztV4WS)S7X2jnycr#ZHm?+752=1^-Xwf@scm;2-%lGqx}5YVJq6)r}R5GuD_bLkg!S zxlO)3{`uIXDrv87tRgY; z2*5cct*8H^x!b>79>rK2Mn;k`xMU2my0NWc;MrvTb=Casb1Q1c#@x_rHO)A1V{cts zZ{y%gvzOm7v3nk2Mf3ddXN?1A_F@OTBX7=LeRpGV&iwJ+)kpJ@*D?OJGJeC1wc&f^ zqS@4f`8Z+v$hpGKdQ^Y)o$DMm@h(`A~BnyY3SwJT<6N{{WZ** zUth*(Y)vyx?%aP53)SVJ-YA-nz77e?BdFSHW)&v$^gT2E$#)C*gd7T+PkwFs!uK(5 z3qM9i*5~o=S5cMC)%TyCz4+q<{AdH4d*CYI15R~a?!a?64Q6%rK89Q|Gzl9+z;!Ks zlb?PEO@?T^BjKy039Q&t-O$#0@~aD@mlke*1ahONF!8OTGTWH#Kbspk?JlL~O~Thm zB4gVVFa;RC@Hcfg9eD#Il2`_g6onJLaTIfFUu9i`879W#;hweW2gWdaCf4T1TJNR?y);~yXUX5^?)DU{l7xumv z`%rLyG%zABlv&_hF`}z#6;AS<>x)HqlsUD^IyJpYu!oBJO+qwLc;E}W0u?hyRFP8Rv0#D-?B|gh_RMen2Zzen(|d zaO`~YI8rUz@#pwPtiFLw*(FWEUBgn?$RFoiGL3GCy?MLX+ez~0?FF*db3o;W|MTT` z`F8;_vI~9Rm>>ht{F^%*WxDdO7O$Hn+7CBN3j4nxJv3#k!h3t`(1kg&M%e{(WapRi zkG)1i6qzIbekc(N!qYxpee+Asps*OD`L~8(-iv2P)_~kGF2k5Fm-WSrX7qVga$Gq% z?Ca|Q6Zcoy-N@4Te%=H>$IGw7op@3G-^Je`6ezursgYzV{ni4L+Xr?Z<2I6hh%Ve~ z%lt${C+hL^e?M0cOF6wlptFU;D0zpW36*ym8lA&$UKkK{lua~DrzeRAkUoh(o(ahy z{C2EEQach%MzaXnWEirK(fhp-@hPplRFyz`Ld5#mM&vm|+4_$cnQFuMB#NCf#WHh= zI)<+^vvScyI*d}VN}7XU><*n)Mk}pf<^HsK)4lA_1sYJM1h5Mh4>N^Jx;TTD6PoIu z?%k|@uibn_h2~gOQ!z&bnXNNM{N%QJ`@=+khdOQW@3YYli7ZDY`}IL$c6QGH_m8oa zTaoZ*_MC`6=zu6NV0qav;xYESx4uYU0WU9gPE~c~Z)~f-yy83k46xjj=gPnA@1={R zo=3JS)0^M@o@bU(A%X#Q46>k^+3>wmDkwW!1Ebjc94u8m1qE%l$-)nwXN&Q(a!iFK ztjqUz$9l$G;g1JYJ}6(r@7`3d%x6vpw27dv_9P4ek7R=%Tjry&iDT_rbTJiVEoH-N zelk9y6VUWY%*;kLpisxQNlkUSxfgD{ z&--OnU}-GZwmk`Ywc;-#_PAMHXC0@Bv1DjK_>)|cYzvukckVVR+WY40XzU*2&dl7{ z@{7Z1|LXiwXnl1PQqbwuH`bKzL`087Ps5DUytxT+Z|Ui&|M2Lat=^0z*kxxVDCF#{ zZY?fmk01T*@A**?>sL{I`tf1&;(f#puj1>)oy`F!ae_^Ti}hdZTlZXLm76q-qpuL+ z8(l38KA)O;%1a&IINoiR{3eES!wmMtza7qP?9lxZ;Qax3=2wPc$wU{+)*pq3IQ-n| z{e?e8irN0kjL1MB@Dad^Ts8as+*ir&ye#hOsN!otj&8|xFPYbMQbxV_AjT;M0;0rd zXr*shaC9^(`9sQYZx7~J;N@3V$eMQHcu5`8smxXs1qiaLAEoL7J@C^*5Mi=>uR9oA4po)^MBD~UgHbPnVme)o0T;7HEh!p z-|R+gZ!;OTPq|ydbs7XIUo??6v)jdS>{7q%)R>3tv{la#zh`^6rN7eRfA`6mk@($t z=Js|p%=f<5anjpwrV_xfWGSV_&Ml+IP`sEGT z9=qyxacAozg>RzF;WxtkN>Ar3dXTPJvRy8Mqz~FVzqkr|X>I!@xF2xFG!fJ_Xj4>j zJVQtK=u|4P(7GYu<;?bS(5vHp=AM;tj!Z|dcuRwbQ(4`Dn8A|BQp?=}=ocUh!%%d4 ztQ{}E0O2CbF|}x-aVx+t$O)nkeT%%%xL*2$V^@b#eD2rwN~nd%XUvSEahr56&)mV{ z%T)ijd_`!b!Ght&%B=4h^56%v@uy@NF`tSTP|OCtF19V*j*c{xUT%Db>OCK0;?z^| zmA3)78=ikUUPLDq94Yj(ER!^iGfeNIxZQ+7X?&-NdXxHzrrJfSw|^OW@Z|P$3Ih{x z`NzYDW|HIx_>6XD;+7tXp2M_D^~LLl%}yv+8YVY`l%P3ywYKnr(foCOe%i65X)4RI zB&WXq(*6r+tAbN=v+O>>Q|jr;s;_y5_IWeSe?HhLQ}Lz0*$r(Mj5xKcDG__VWh(o# zu0$Kqn6brmM_+E+z6#A5WzHA2yeFh)MY2@hNtgB?)lbtg`5kG8%j!Ch4EX+?#@2uA zyWZLK>fOBhV^Um+_p9}{bpkI{+`s!E2Ry*i3%lO7`u-omKUuywV_~`fw1~~wJBsXF zdMK3rjczm28XZd>#_gI%D`hE>bAD=^lRdbUSgm(MU0D5PdAU`@q03Cjg4eklHLw3b z<2yI|bBQLCQNyNcd+h-ai=jiOmnq%X$|ykY8?GDQ|tcS+ee3&bMGqQbmQ#J zCs6mdD#D%%QL@2qJ72C^DkbudAd5OMZC{+a3+I6oFJVD4tBnSBT|OAJt;d_E>Fc)` zAvKO~!C2uq1&)>x*)GN~#!e;s00o(mdv8-Q;0y)f-!m(ucVzIdb*HAs0&$@*58qzW zh;Fk6h)5RQ(a?p?Knhv-b@X3>>fVFZJ#M}E({nms*%pzH+=J}`db^&qR0a_ zD!Q*rGCFCKfp34u9tw|Z!gfj?4$TD54$Xg2eH_d%4Hf7BwZ&gmoE-8{CONI;UCO&f z{H9TW>gGKUCUGK^N_rDWpJ!pa-B^ZfI9OGh*-@j7Oaqw6ECd(50>_+*1x!l92?p5|-t!O9xdv=;h5OB2*$w?-%>fCCAtW#;xvT0 zl!aM(ALD2f2w2?BrB3=YUN+t6VG69!6!PCTW*eO080M-H+Sl#bVE^;+*VfYxQN+sH z(x1b1FM6pcZG9yM&N^Zf5e6+wi9g41;wv+)#!?Zmrvt$(T1wQwFoke6-Cod&4ynpJ z`Ls}E&5*QI~W>6#wb@|)$#UtXM`P*g2O)prWRGQ zY^h|AbrDP$$cOohf7@11NAyX%g@abFLUv{U&aEqKY+OEV@A*AVwmogMoGznEiHO{5 
zM?9H%{iZ|;d#Pe60TV;#ZjF$Hvu(t7Amz6ziQn7%F&oG z`J+O!V~h^PLwmu@;`47=Q&Ckn;I{rPafs~qzsn59&HICext9ZUZbDnqe{v<5oi8^3 z;_F2O{`kA1HrLJaOaEc#n7?N&?B%E{x-u!Xacx6)L*fNPJMsPR%w{w@IN%nADFsY{ z;*2TDSm$#zYQ*~hggxXHnw`J>_2qkZ80@vH{(?SSZ&uoylBGJhE z*+f~%zb#pzDRS)Ah26VB0YWac(fcTAcv@?GsacPU?@*XTSZ2ymERoa zPUi#=!`Cz9&tSFe%badN)3HyNZIOb|*G}(Y?6nl>rl-Y2ce6KvCBrYD#~J_#9_BEN zB=Qao;hNjY zDaj8TPgnKcV^Wzl4DM1l>yH1{isR;uB0zrs;#3`*BpZHlKj~wIw26`pR|E9V6V~g3>V5srKik!ZrGY zCRiUDEQHf$PhFoYSp&_5X$VHGHCmoPKwanhchAkg`w@S(N?UtB|DAfew^E>Bv##ak z%iBTTG(hAW#sMqo!OrCzcH^#WQOhu-A$8H_9v^YBT0FDzgxLYOGJF{ zPk#C?#ZBg@hdK4hJ>sYRSE`S3XN#2`I~U!u*9Ad4qvBqZISO%4I~TSkb0=aJVVohmjJ~|r_5zDw4FL?ZxFM_x9wJ*0_e4y={ zUF@yLe;JcTFX$z2@7BG3xqn$w%lRQj)B2OU?e^V|&^9X_?qw!_)^Eghang{}Mwb-u z90Py(sbV}Gw-!HLJ~iWztQV41bYPE*gh-iBs3rSbv#98xQlnE}iXiyGPcT!3`-+Q7 z4xskd_Z>E4syMH1`O+g>zMQRUa>g1L0yR8opDj#rq>R|LTgic<=N!o9!0VWLkoM7! zh6C#_Hr7a}z`0LAz(GHH)5?j|&CN1YrM~Ya^8%IEw{9tEH1=2AuY`jyjjig2(|V6U z=3-HyB89!cHO2(G>TT6lA};^6oG?n3}iEGCMbdvZ=DdckPuI<$UhU5e7A>i+gV-rZT< z%1qAo%Em#B5}^pX;-u(8R=tLZjPl_;`%rQck784+RnR!~x;W z8qii9k90;~%A**S7a@COVksP-NWiV7^Mx>>5&QF8Zh>EZPbOO88%tr1a^*JKNTttt zO75gkWJh0x=qd2wzCCHR5vKPGF+`z0`0tSLm(I0C9+G`A zh1Bt2MwuhKBw0hTT?f*?Ja`O&vxaA7dSP#9qhH|;-Me@s=RDHxy^h;fnOgH;yf5u} zeOIDjn{T+kiT9E$CPIJrw_7p*{04JEY$`k)f3`YJ_*ZY2NW|Or##&DYN6)LW-@KG5 zxn(aYL+yA-K^RA;wiIOTPYFT~i|@r`UTv};yY+y*&`YtK?U#$))EhUTZZ-?W60BX& zW7EkWGHch<5pU46-bhV>JbEaU5%)#uUJRZeq;RP5{Y2OJwe5bo4dtfG;P%Yz#_`}gHcw(rwdqtx$)bdQ5sC%;y*#TnM0`hlsqm|0TGF+x?NMI>bwUyzA zKJ)g4c?rc6vA^qk{tZ>{uO%>q=PQ>aX16!^VO@7k!0xDRgLF|c;-k)(+VbH=)v{2Q z23`6N))r8w*Hn!3n;W}C|3YmN>Ak`K+L|v7J`@MoWLk@H&1Hf!h8ekjB{x}+h8hHR6b<9 zN<7^GCjKT;KJk{ov;7p6u3+tbTP#4}KSv&|Pm%)_aG1|#{h|umRnm~=;>&JeVgFaD zYe^IgW#%n{{LQ6Oz3Fq=!0rUupDr8-#$*p)WY7}LBD%pYWddRSE(jr=rjYs4AV;?< z!AUPR3HYH-gja94pHt5Msa_}pzG`LWM{$a$0>NOVbnRam~k;-$=o`dRXj%?$i~BmDeTF zuQwX%fh40%1g)x*`jok6InE%OKeon`JMxQbLa6K<2EQNMKDP;^zEp=<8+~6h(+Ur{ z+AJ*%%UYsxEJQ!;n*Lg~(#*FEKbOroe_1s}4kBK7WHM3=^Os!a!>wGTBpPak{{G_F zYF?PpUNFfkyd z*4^y(rlgok67fH|rUoB~c0+jWF*BosSo6S!`U^pM8%2A%a~l1@{a4=OU*x78an-Gc z3^RT!^gMq4b-19CRnEU*93l`K6sf{Z)DZKsiOnzWujvPwHGg}Qa-!1jEC|SP1Hpej zspdXbiuqf6)fhhG|AMJrh;|oRa^=iwisF$KV<6H^Esn}Dp-ox`|L?D+{}>tIcdS@X za{#nRc^}}`q9UA?E}!Ns#i|+_n`^!fE`in3tHL2A-kBK6@nUtiP2561W-BM{;%~xA zM-)CRg@*NYzxufe_v-$6c$GGa8f%4pW}D)6h=c5nYeragoKuw1VaCq!@e9>%O_C8M zEa2tE=U7?N6l7!EsXYf5PC z+BTu4RC1uLM}A!dCQ+_mqDk0DsK+_qJ(*b9C3q3j4_C{^hZAQZDq25xk8;)0IuHk9iow$|ioT!O+CWwp$CZvC8ug`G>FcM8vxQ` zG-3AH55nzkHHmIPg6S`LB*|W$j^h1%QB_snzHo8|ct(Bb?&roak zPSUV9?+W~qe`%pIg~gX95IE&f9GBbH2@60)_nWAHyU%4BW!E)>_aItX)4wkMhfJY1 zY%xU@^<6*;@%^(hWvUG@>a!_vzd`q7^h)u#PM?S@ zqioOg$g4A@k(Q8#wFhTD85IgVSeN!?D^?M7X^oWUThy2^)3*Fv3dezuV+jdhS2 z>6vkQHGz^8{ubG@0}!BLh}$tBRW13y>0OMfac*d065ssbk|=LU&N+ayB+1+$MM7Dg zv6e1`@}46>Ur8k0)FFT#mdieup9vTnj8$O-WJbAe2#@l@^kz563i-_T&UfWrUSC zLjf5prpTd06tUt{pH-feMaqInd9?=TL}HY10psIt&Y{W-_hG}cafcJ%w>{z?!sa_f zQCJzWgx^BAK@LpUc)ogbOQ{^JrG<0SL-X~`4M%7qMAm$f4K4)GL2Ew<76kXIHqL;P z(7k)*qZgy|NgrYC(^O4<$!F%NB6IIKjJ4S9>2IBZ^G^S~;OZpv0$p^bmqbcu&G82f zQT>-_WPbgrijqo0l+uF<3m?SGOUSz8*>MAfHkMZsBQD2QMGd*fYUWkNE1gsCE>dhTF9Z~0GW5cNXxWc zW6KHb(<}^jQ9Wq6rWK?K)#+e-ayKO!nxJJSSpv{#YzITXusO=Y;ja3%adUjKfE()^ z25yjoKXSe-)WHjqbf{mAER65 zf+YoC>RPMj<8D*lhhXOvM{d_J^m$@5bYC&^2LzQ>_;f@bQ6%#eWyawu%StHr<+ZVv z4HNP=UvFt_IwrGWmfRdlg#<>~N3eQa>1vfy!&q z<;E>9*jJ7BK`Dnu9w9{FXtQ|@CNBMnlXsW?8b-WR2!^qIVRRv>9wPprL&@eKN$W`= zBYb+q66tNBR4r?&Y5D<#w`mK|lTSr*YnW z%0h#}#J*1@8Q{`;3hzk_@qy=*!N2_b3o_{Zifzy)qvALv5-J&}YfKwyj}e5UAq&q; zMTnR+YyySnE(gq=z^Y>x4l}AG6|nyuW5xe7#-7Zozp%}0TNsv`l0gBlgD-2wY0d*0 
z4Q)*pu>Mth27pn0$ipMi-vAQKrRm|d>6OQ|<83r{? zyWs!AI{X>y%>@*dY+f#YqCuA<#Y{bF9@FII|0Sizf21@;Kl{Ik$!P2Ss4e3X$^?2j zem|y^tS=~Sf7Ip2F70x!K>ddo)etT~-7VS@ejk7%NTA4jB_NVb4O9H<5{HIM94$~l zWdN~9zQ3-z1_ms=@L+TJPR`nFs`w^8N|S9{P&DxYSXsgo&x@NvKF0^z60Kk#Fjuw| z(;Z{p-y|nd1+BA z-4zI+TR^%|232KoptEr0tuZaAT8RKxuznmvx=a~v#gdVZ*}%F3mtPw`?85q~m_0dm z%7zyb^DPeu0q|fBtz(> zfR}iPvmy?&0Ra(>}tZ<_b*WwF-Pl5JI|H7AZMKM~(39`uDN#`)c_xiX2ZJQeGJW2lW&o5-Mx;#|Xo z{f3Dpdctt}N*+7quh2WqEim<#}xtv&gJCPb1y+3GRDhZ5(qE z^;Hy(E4E}ZU`Y}O%xBs zkvv-I(m#8Z0pHEWbbx`mv270oPEM}fm`et0SP3!sbTC@G)hV~%r4_5)D3jye&!MjR z!9oK-Y7C}y4FU}ED~`)dBg~lFT|e1YQ*uk7Gz4y{*IKt}6p;LgBE(TFkV#x$JM6~y1? zD^z!|JPeTta36c$_DKMYE}LI8kEgB2QkbRa+GF4{W6}S%O@Aq8%w}yTW%6O`PDnHY{bmkqSPZ5++N>MA(%P+ zj9y~yG=6>lGoS$t`dajKjeAJz^u^>Nj26zp@jY6~CQOcmk5xQ~@?jAfokMSXRU;h9 zTO}@0&>VpdyhD2JKABq8!$22$uV6dm6~)Wt1&<-YvNeZ-Gw@rbH~Fq5k#d}w>5?+- z9CP@Wl0A)jl0gMJ-vgvf%KbdvS{<8CbGFtl>he8@JJLS;wabv!AFQ*3Twahi&|D%o zL%lWw?4i2DS9qt>nx`~qjW3j*D@{uA&Lp1`R@&$ySURH25|uc0FbnMNmQ$y2h?gDW z499oZ4S}&F2@vlIDL&`O4z2>>%PyE|4J8Z5{Yr4l$mQdVg1L^<7|FWzamfO=)4yuV zu38NN$oIZ*pOnJ{=AfR;#sK~V2E&lF&K9`g^BTTLMOSsaAl#O?#tj}?EVIT#Ejto=1t)IwS zKqfZ>AqNOIR)`s-UA;mk{oHX@VC5Y_DO-wStTm<6gl(m}sYYc3iRveyUq~|!KfgfH ziCDesx~+#LqKF>VUacPU#!?tR`V1N5YQX+))yIMwT~Gxa}X?dkEb|rpkz1bBnGlFS*V+9@*iy$&Ej*k1@nO==dnRUr_=U|2pDQfw5?htmD~ z)lUW=U798CH=OmrL=`b7kT&E^Bv%-uG$c}=RLwK++LXd_0-GyuDoNhG+zON zXgG_Rxo{-efJifm#}prs*Y`O+jBxEX$Sb2hCbM8|y(@9?qDAFX&W5`gLCz>ss^?E_ zBxx`##&C5S9 zU&x1wp9u1RlRHYxFw|0^^eV#}l|M1fv=p`K9Jd3t7bx>WX-u zI6ffX6OUB06c}~alA48e1Fp%k_8%$T`s}p#HEyKkO_ zln0P>=jK@UsSw`0*AnA`4%h@iA0(F=a|)m+V?@!d&#&r~%F0NbHYsFOo6d;1r}AE5 zik1`EWQt(9x19yz-MiLk3u1mu1 z;@qdF$VIw!uhw+p+mb(B|LI}bwGX7-WhWw7Bl%FEFDRNUzFvEzn)^)nmG3TZbJerV zMMjbv(rVc{E5=9pTo7V~x>d!lOXxuR^P%BOdK$MX(3Q(7Lit$fG5xkB}7M zRT?(dInGzDs_b0ld=H~rmg&%EL$Wz<*22j@Z77Gw{a;87^|k-PLhote%bS!Vs64`@ z@v%SJRM3~!nIlmcM4*0N#0$Qxw}u1z_X|l9)S?q&WRbQ+?s@d%&jw6TKK!{ulsev! z600Sj0E|yq8qodg@DdVIP-uQhcA>dL$aBI{&ZN7rZ)=uKT>fGKaCr(G?9r+WxKZw7 zxuq>UL&HK+=>$sFVTp?tHh=LAMc)X&9GLKSm2rq+R&<*Ot!}6DK-8be1iC|6_zg^i z)J}+Dj167u%CL_Rg>bYF^(1p?Wy15;@(*)UA_SW+I5~-F<5N!I=78=Y)6lV+N6T`o znbzB2Jz{f&jf#x151pSsynZuPjEc6%XaPIB^? z(Eo2gKKEed&vcUa#7NC``)IvJFz=&w-{lFsTi;=vqf&t+%*XQ;$Uwf-UOHq|HJr9zFPvtc6 zKl34mPSZ`ryf8H6>qsW}tY<<|9m;Q9BlXDCPI=13X3nwsYY57h^pwTrn?W9Fgj(ak zmRniW!D^LjNCs1fh*&oi1sH5<`LmUoy!R zgyLHGa-7?Nw&T+TW__)I3tua=m z$miJE;dB9dAZT43tIfib+zw`l%lm;PA7!_VS8Av0vEVwgmLwn~DRt}JdC9%8UgU$|RWZ9l!SOj^ zBx4$>S7&1ZxoR;7Me1?CH8<|d2}{WL=o0v(^qTS~l zrT?{f@$=%Q|I>?qpbV}Du)?o9WkL`9_mB~=^9Ftj+&xLeq`jZ@T=1A>C;KpjjyEH_cmSw4IO2q{p4*Z1xaT zOeDG1(dQUCb>ffGXX4?2$%8~muDHaJO^;t=nanhAxWO1N4)kzF1I#@g%TVYzu&M#O7kn+CEd zmd~^f6LTt$Dd-jw*M02403mfAc4xD7AXjc!$ppVep#@bT0a!9ja!T`cG>=`IHFQ$XDTGH3 zzh{osQ#s~wSeZ-JnJbEMz@~JdM7dub7XdcB3xYeG`d4yj5TI13G>$dV9l6#E$UF4Q zyITrm#1O#ppIuo67S6}c07_Yy2*ejn+4LASB-7i5ETvU+L`g{0+TqBieVGYB$j|d5 z3E$7`d*eBa)GAO2Xh0BWGY&2`e~z<<9T(MRNi*4oCha6*M_IEdO~SbisT4xTi0W5- zO1w`v3Q7sgXi3#f)9cUril)#|WTkSHSqv}v8r=){C}I;(QGS9cLjut)2q~2-c#Pwx z@yigYHt;6g=Obe&wQBpa>jYO;)9M3r%P@T<<#XSiZ5?R`mtoZq$PoOZWhvQ-9WL;( z@jNweEwE=dQ+%c;>4NDW%?% zYtffBrQ4`0$g9FgGwF@fO;EO-*yrPi@HUk)k@KZR0)i-?7{-BLJwH|V58z~A;x~l! z>yrxc_=5URiQpx-r%NUeJagc2Txyq6?0u7gSWW4*C2Feig~*J;xE;&r7?JM!c459L zkAlh8v@?6ySiHXAODHRIvhQ;Wg)+~8dpLcdyFt6)ED8u~*`u1PN;U`Sn^<$jQ&ZaL zVxc73roN}9ejY7AZ77)kA$12auRux*K;WRbvk=5~g~!B*^U5Z+t|v+U@4WyrRaTBl zb{siRR#n|;eeaUiElXM(o501+c=&z(yKRY-)Jka10;#mt#w|nbF`FD>FQ4PZMv9*? 
z&45(3L_{!ZstOjZrG03*p_}MOCF-@M{Y-)Z3_FT|Pp!P9hOrMA>8?6NNsw{9-93To z%+PPp^g?15%EF8<(`a(K1V!;tHl^MY7DR&^T$6kMC9EM2j09x`41S}F4ghFPrW75J z_)5BvYo4G5lAbo2Fr+p)eY-27B%&Vx5?bVqI7*_LDP={W&=QnjQ5*6K@)e3r5XOxL zbh<0a06cI?5JfDODEsgzC{@m8<)Z1d6A83cw1YKTpx;=DzQO|*^L&+}meMYu7wccA z--}OGl`^ye95^r%l$=ZbnoxAeqd9D?jXBM=mIuAtCa&A}Z&PES6qB|XC?HZ&s|7i5 z{QPyXb4X0S7Ohg+^VIjDPWMdg{WF@jo{C|GxO%X{a#=%;R-fUl15+dxdL=5SDXD18oe(wB;(X ze)e%>NVC3r(d1e=>H(&~du>Z%di;$Cg?yMP@%ha0vTiEbHHn5CRT{f`6iY49JJQs8 z+)H->qCNGH@lpk~WZy7t>ba&;mCg=%1~2~C+4~3%qjEF4-8k7z-78amq)`)$2hVT) z03l72(nhwu@R+Zh;HgXd21l>?6z7tQo!90z-99y6nZsO?>xJJi3V#&9^lH>)ub~qJ zoLekIXZ%Qig8GNqBPGb9k{6K#_utgwpddor4ydtaz2Fp`@W>J26bH)VdrF2F;vq*@dE)hnq5^jnoRhKo*!84C^bad&|Z7X!<2 zY=#;iJ;XJAozDjRHIw1^5HF;=2g3%TUSwY2=@HNPFIZ*Rt*6;?ny-;#$|Fnt+Fp87 zD^l;5A1!OAlFKDPW44yLuR6auW>ydJ+LyM_`?x}UiUPeG1L{1?Se`t3vK9|E%6swV zO<)D~QA)GSW8tN3h|4WaKq%7KjS?@N;PZev_E}!4fN(ACGBe;yD5W z`9B>kTtQ=s#&jU3sod#&`%Xm#)K`Wv5IQ79|1VZhKGcxv0u3l@(e1OtPPSHdVnAwde1Z4xio6g~5v)K1H`1vo-nYN(mLiEaW(10Q!TmKyoux zhnz;umEJdYF7`pAnT&ZeQOC9T=kZo7A!M`WE;1sJJTgd0H*{C^rprh5d10FiQd{bn6Cw(=%&7loj0Uh zByS+e87wbQw~ez>^b>lKNS4qo#8X~9idN#@HA&`z3BlQ%aG+2i*fge^;_;-AE;O&* zx;9qQd7W~4%N9_u*_t|Sc8VyuO&?Vj^vu-BU>WekLE+MmMFD)~{!~%!&V@nAN+nm4 z%!mFH=d&2Tpl9$9^P{SqY85r+AUAF}%qmXNy)mOCPNvV>3(vxgRCa|0L`qRLr`k{U zE(BAhj}j^N#wZ`WpJ5|mAw$XJux{}I%$+KKX~N0s4Z15%^oCKY~xmd0PSjjqWd^zJ@J|U7PnHSCc zUr~dk;<^&LXk_Cmu#&Y9Ymja8I!ia5oD-Lq`e3#m1t2pNhf@2Qr2xkXk9YjE+1Ew- z*>|9j_t$|`mY#wsAV&-bWmPQfEnb8u^Jju1Vr zOFT-}$H>$|g=YGJu*dXiO5CIvL-MQ4&q+lkU{&XSjua>+&N1K>hW8`uv7)%1v#Moz zhWg&PodlBDmH9I!hpD3VEIFTCShlHD4%L?;P7G#%C23=9l88xn1Nj9qR=9@&q9hMA z0HEIDN6=z%z2-Rs<5YE{{84p)&+uQ@m?CowNC$v|N`!Ut${A|eq>S1VXJ^=j!%d4* z>BnHO0!I4-M$B`+nsegK({lG#ur*6_QGDBj139abnmoR8hzO4ynCJ=+dCITKK9D1? z{#?#5r`2YdbcxlW`^_si3q*sJd84ikF073|Q_nO3)VNoOR9^d%@!jOoU@zYMTWSS{l)B)cEJqS?vsg9l>X~u)8cumF+Ta&ekwUkN zM6Ikl+~5y=9z2g1Js1x5ZNU@P9F<1nUI)AhENE=9gs4sPYSV64iV$J=f3_(-Tx7Au zPdTO1Rtrx}oi5_S*}yPvNvWguqQaDYh5GbbpuD>m-m3WjwS6K^RJYLot-=kZC21K# zmaxu5jNnlv0WvjsG|yobV^w!J%-pUSX>*x1L0JSu`kU%fU%Cj$yCv8&5jvJ*Lpf{= zhGxJ*ikDI~wtoJb^uDG$?!BHbZcMGcbkPFZ@nOrP7t@oiqpG_t8YE&9Oq{g49{T3W zC7>ZJnkU(?3skG+fc%7}Prsg@yrc10qf2pubjevE)1-^6>y3$BcI;Xk_gVI0<7}t) zPYF+(EAP$B?M)gN9Rd=7N(m(#cbl*kd~hVjEM5QOkGx-4(epogh(7JR-}6NKIrjcO z%=5eGy4cyKG>cz-Q>s=XbUySI#VUH3noNOn$MgpeDUaYB8|enRtong`A;GS_BPf0~ zOKlgoCTy3<8cVF|oe%r6wIGc=3i` z;oJL~b#*P)g7vI(ZwpNKmyjZjWt$%yd&vPEL1GC~6Rh8qU$XRaj+A{9dZ|lCXTKsp zb;xZ1%y3~3$GWHG8Of$r!(H}uf?FrRLGUAndlN1yJO^TS;ilJ9z51lu)*EEl3DI2E zpCQoZQtJD`#?^{!!>#u(Ap}-jBiujR>9^(R(|LW;f7&!_t@Q_*<`g z+$op!UMi_W=LtOQid_vY?UWx>b114Q@!RA>g5MWBlR6W$sJRiOC4GllOi$hSSGjF;L17+%yKk8?% z|Bbn?(24_CwuKOUfCPuZ-95MjcXxLU?(QCRaCZv?cNyF*3?4MNyX)huv+ntZ_XFM4 zy{fCbtM=a24uV98E*FU$BhX6WtHk@wBP(hePAV?Bo5c%fn)FHmg413g3vNg3o9Q;D z*3L3vddkDcpKvF_a|cPSKZlND7BPCoGz2k}bdcze=y`DlKIXl}bCmm_h4Yk%RF`kNK?5vAr?(O~ zKJ}SJ&!qAXEq_mi=qE6lG3KNe#eFf|V#l21M2Dra6qR<86(>{Wk-}5}*fkbNMW(`z z&^cX<+>mg#^nBW4k^v;2SzI$hL$TB7_nc={7vKC-x-0*~D}r=&#Igza%*zLAi}4A! 
zQ+)b-Pl|4f`ThBe$3VExMj=sb=meB#262vRMfE0S%1WwC@#$1FvrHHHVoO3D;gm^; z`Im;9V?d}dQpA$u#lZvs)%>6-<`5799e5(OwWh3_I?0?-!dO|$f@rkTd!k?lmj(UN z<#dzu!xWQbG0HJ*TSxJ`gYZ7YK!r+Etumr5^cd23m{D2@@;EX3AWIUGC{$BAE9Aja zRcW{qRc&!ZKYr07s79kI1%Fed3l7<48VtJYA{V-OEdJ{Ue3ht1Ka6cCrE{qpMl;)p zh9x7E0ftvz7~_OKQcePXUiJ`CJa0x*-hLw}5rWGn|KIPh5hMB ze?{sJ)j-0BY8uKS#M6Ihpcrv0lQz;n@lS#n$3n+m4zgw3V5j-P;c8@d8w1=>bA0H2 z$apLN7&x_HSY^AXzT*};(Vv z3W2zQ*0ye25HDV))SY*q_5ZgU_Se8(&73iS=Xm1Kp@!p`@j2e=i*eDP?R7@XUV9cz z%XH8b6eliXd?RflKbT_TE((}H*HL?@%qtSTRY%!(o{O>NWOO4Up_ss=W#-7TXp(bc z%*pxvIFgM)u&8FFAYmn5kALlRr zy6RZq_8#8wyqG6DEpJ|IV`Sa$S*c`j#Ui1OJ)SOQ3y`rkw<_T1cot(Ux5JOzScue~ zO}Rg^@gh>eh=RHI@^+A)CH5=UzTMO3xCgpcX*M7z#azKmmUK@UiyhN!tC(8gSu$v7J& zju$|lg;JxLDCMC>7TfI5F8}Kc9xLUKs9q-P2`m0b$WlyHHud&Zpc~jJkV>ee_`8Ie z?tsSN z6V#-lc4o{EqzEx;8-b-bY+G3Wvm|l$AE2BLA*A|G;w}&p_IP!z_XvvkbRa{M^CF>M zHIQ0~j1cUq77B>QXZ^S{l{wW|(P$fcZJ)6C$7qQJSR;o&a_((NKEZ$P(EG)G`Y}>N zTk=(sUyPSA(E%UUG}pHEL@*X0qaGB9sE|^EB-0AW-Q;>8_UQ&Ak(#b1Qr@oWJIWUl z3K8u2zeoV>B9f^NQaF3#(=L>wNf5YH!1O3EnK4JeuC8&G{QD1a0lm^7CzME@WtNh= z3uPX1;6JT+lzym9tQf%8tWOCl6id)aw6(Odz5NxQ$yI@(@bLzV(5PBxrXho*DNkX2 z2LmQ8igHA_$`b_-x8wIY=X?sZ7@Ug_;p*Hqs6x4T7pdgCcgQL!mvMxHUklNBp@>X0 zM4UbTL8287H6d8Vf;;VUCZe%@Q1gr#TOJM18Y`77?dJq7 zWyiE6lFxi4(;K)1fFZhWJKBlQMjcSKwxh=_d^tX z)LbyrcWfkcdXvbFG+~QGTu3R!(~%jGAw@c#ocbzlB%8tF5GgR~hF$A$iZi(lz26_V zog#{2MOhP(wLw9o$+de$^aZpq+3#%@UH3TwjB8rB1sEs4g~#lXFL3R_6sJJJA}Vrr zum4wCf0%e!8%x1Sy=}K6g6Xe7eTsA zfaVl}F=C-Y_1_XEYsyDX2oqV3#@)I%Ptc|aZ={U^^JWa=AX>nwVtF7H#&B76$Q_$` zZZ9ni77MLD(?d~@5SDmh6S}1&sEr@acty{tWP?Sw(i)4GQ%63G{R${%9!f9g?!S_H zg@$!hr6FVy|0*pSHe8;aW7}F`>n-*)KOL69Zj^obJ~jp2tvT5d{}8PPOc%n!?DVlR zRjvQj$Bq0@OPfuPBB@&NK*(RO=crqukT(!LI~k5S2C7v8uw5c8D|v>cWBm<`g%73O zdHTZdCslGi{Le|zjd|X?LuY4RHkbS9s1mctnLE}mk=@gJPuS$5iawU3*_Pf0W_U3% zR=$Gh%aC?Yg#14}Ywk>>!Sl_;-QX1&w+kN>K6`~?W%96kGYZTQxlR#IOaoGjFixXw zXk(-{!Ox1sDu*vlo~x>BvQ6vN*j~3F>NxNd-C#L^f~nySz#!8r0hj++hCG&WCHimA zrJI?l>M!_EC+rJjRc=lExKfczK2Ni@&1MzC7lVFHA)Mgp{OHEt7__Z{ER?du!DVO_ zu5aA@XXGl(^TF9~$CQc$ez)q;$~@Ua;1Rq|*BE%OZmt@oVGJNQSTry8^xneWzb5Rn zUowiHtrC5@(@N=9Ea|25CtLp^mXI6AR3dX}vjmnQ_9%{I8Z<(`QP<$*WEv@+P(q5I zw=`|pM4@8SUCuU_l5~`V5LZOsd3*;iXqvo7a zAQ#H`)5Npl&{y{sJf7BpnQmAW&q&wD(iO{LJs)N+t7@f%j4p(mAHrZ=0jqwfd>R*AJ$ zlj%m#5)nDO0Yh%!86KCO() zDP)>1?9DA3j~??%5pyX=hC=xZKza3x&=GFVN76(gM2?p zW-_RF5^Bu|7l<_U@z~`st58V^{#-y#*XwP)m&OwD8xsknAf6(oA5~}dJ)L!Naas0? z8xlJ!&d(tPQuR(=p4AmCOw^Mc+D|k|7ye~VEt}%!#L+5D2wkIQ+)cy=0;XW7WU!7- zsHy3b23`Bgipp97`%OY$#yilXETx7?edG4aEJ`RSf<(T`6&JJQe!ASF!Acpet0O~} z>P&V8=Dr$VW*_|EXk5^5pId(QOX;>c%vj4qOiVxik&;W$vrS_+VZo2gE@P6O7@92+ zGjm8)Kr@PvdnRU4vWs%y&?>V%c8YFQn6~3oeC!aNilMtQb#fT7D>V^cMthayFBR`{ zSI7g#g{%D_?W=I`k~YwAvb`YhUeGae5$nv(&y*Jm810y^h$7^>(nYX->d#h^O14rQ zb7hvqPm#3tCMxRu+b(}-V3@Ia+RjEBMA)&ccUCCgf<@1GqhwL-If`yE+K|~ghCo+qeIJIdoC_vrobO%y!c~J8=h36<6f_} zj3XHWocnqqgLQXIO5Uc-Lz;q@G$AKGkA*Id9u#{xQG_c22U~#fkX|-qXBlAvR z$nSfXQ0G;Vp%U((Y&xD@;y{a9SHONi1katUa}he|V!h^gV0KY3yAr=hWOHnl)bZs! 
zqA^4{(t~p>#na+AwpVIrktPpSm#osxnvf=0~sj2JI4d`UI zYinGd7fN#F%#oFv8pBl?jH*LRAZnT@7U-%fD_h~cLedAXx;n#iXbF?$+ag3CWQpD% zNCDxbB&Dnn%|vSBkOv)XGk7vKdMiE_BwemVD=P6frq&aP_JW!tlPV?xg=P#T4I0@S z5S)|F)=QRsz^@A8Sh>J7XECa09s?=wvGUR$DUEBh5U_mYi(@%Jj##7D&qfhn!@FdO z>;yVWN}Q(G37c?i%GBiPX&7SND2&(l_?&k(&$@JnV=}l4SNZ5WOZO(43A5t!xc{zd z<)!9oYiJ1iuNH3=3Mcy?-PFlv3EA4%oCz6yw~B&dHk2 z=7D95f`Ko!G{eMB)&LM^WNG47u5WK7QOBRHc6nXo*CTG1iGP|gRkJ%?Cgf|$t`How zN(QR4wio)&-RbKVc#RbwYHRZ(mM>0)A@94Nta7rDLceWO$J;O0=VBORVq&YuX6b3E zId5*}_7fB03fUulmW5JWtJ{ufvDaN(T_r2bEe;MY{B^1N_wV0kd&BsH9%vYlD#*nJ zX>42)g@~guAYRDVK+x^zAucD2m)R{2$c`6?kV1X6+BUo5VK1Pm8Y@AikekTOlKAWU ze6MH+y`-X;hm!NZm*wAl%{{YobIBl(o(h|%%KXIDKXSxXx3=SfDDn<>2LZn)HCY!? z6`P~<8TAAFyg@#9yN~_-ViXmHh==!~ganY(u%QzSfBwd)zuNZWhvj}rlzLJF+qEEX zCoz()j!s=jL1O3XOZ3&u64Jb-B8-NrORCM~^uz=}PDjRhEN)HrCsH((Tsp5~V9Q7d zk#N(GMQo1wc{f_>4x^K$%akIj;^JcEv)=f-v(+}(l^=Z%v_%qel$N@><8g#pnL0YZ zhUCY_%euAce3%2t6$&z~yZOpO);=Pv!$}&8A?_q8)?`lfEuVnN zWm)~4?QjLzlSNaWddW2wm`rH1kyv1Oev8(bKz?9yc0IJ_8-nS{uP$)AVK?DW!WPFJ z8wo2Vr33dV>iYf7(zvBzZ102N6iRO7#!Y4=+Rcw2JIfZF*55f#(d^UZ9?k4jWEa{p%yio+lfxMX+17e#PPI2iTp;A9ExGJfT;A?JGnL?##{h?`MVrm;IL*RHh%WtA*PS0Yo9{ zvf|<)?7{r*##RpIZpZOd6bJC@mML0*-@zZ_Av`|slOi1dd)L^O)`b}e#JTO5%?ceI zV>mV^5^FTmWxo4i9xW3;i=7R!@^?kqZ)CHc@7!~0?M~`0blr@pq6OSlYQ^z8T^(FAzdijf z>F$RZ+bf)vR5LYhczXaDB{{!+>+C+VOUUN-0(8Tp5TrZSy=_e#HbemEWMyP@z%RGE z=Us`ZsfdI*Ij34OxkC5h9p}CXKWg0E+{RP8)*- z;Ab7BF?~p9rsWOcyIM#r1O~y7Q1S2_rVe^AwY~OkZS|(dp002i4;>vduX_!L#0%b} zbiSC#>A0wFchILTWb>)&siEt}sofczEJhEn%=v!4fSb~EEKVAtUdkdzqfd@Nt$!@v zQhl@GUILVlCys!U6eu_jKTbV&^wZgwA-V#YgL%8d3soKs^QvkLUVb_Af!3lS%(cv)O#r>`Pnzc!?X1w>S)VlupU+b&SvwjuzNeUfL$uDg;5tGt%3R4DD?$mL)bF zl3w9fhSISXyIN^j2AUn$IrhyrFZq=+LTCX51$oG?7` zvu9Rbg%l4O3_iGV!c?pe4|a1p3~*}vC0}Ua3ZgoTMzq=$@ID25?eMAiX`*7^z}MtD z6@`nb^Y&i^7-{FgP(Gebw=@Z{ajcqek1mpqi7J+Cj%mSe0(jT}KRsgifu1f}*0= z50{_F_G0wg7#p7=GhJ4uf-;6_Ja}$Ap`@ybeuDpz# z?FoDVb0IaZnERQ_IljFq z1_tY4?d>j2mURrC_xgGCw6wo}|Mv1Zq@U)0_Y@W$wG?_CoQgwziEVEBjYvP2|Es2^ z+wtuso97y5%~MkM2fKop-~Jkxf!AwTgy3t*^O^s{5U^Ht4*c>o+q6Hm*0j?b{nZqr z0+CdU;eD8So3&H`Ffsl-2}<)gS-Lv8R&4ZW-0+wt|5ytii&4*^>hSlgn`=eP9h?oX zeM`qaTfhFnYpu8H8Xu3rtnFb1y-q8{b+#P+wAxLd(kD4H%4A+u|I8ND?@zi}>!=?Q zgXY`L7Yh3icU zQyNpoA@~M5P__}iS8Z468<7SeEKAXh6S|E5?8WwAK@my`rk`8#|wUk};e<5d(MAurT8H!SWSQE|0iE0`ttO6Ux6@y!8 zebf4EisfS+Uil2>jWNak7r(?@)p?-}C znJ0S)i+P2xLaatGDP<99mC3VaXtD62B`hx(q5ze64}lo-HN;>%V9HhV5S%$sYUv|e zp5_^ie~+YaL-`qmYjm(Pu?=(^zxWj>BtdIQ?fTpvtF)HNxwvHqr1NyladOuiYy2sJdJYnY$BUgX3!GffE3ZOcdcksn= z)Q~#?i&{asoa-i~T(o6(@^AKWnz9G{f}1bEd#vKZIa)gB1K z2e}M26=k2-_HB#se)4ETW>)69?(I~fDSRD)be~MLEcQe<4-QzpOH-WJdyot#Dm9QzXj$ z$^_-)T=!<)2ew)4`C3qI4^&_@HSM2)kJ-7H!AT0pgQ|xI?5NbF)!H3}&CU4q#d+ZO zx+;bmk4{Gwp|^yOJuM(HMp9Mxfo zun(G@(|VHcI$J7e1Y{e^_Fs>xN5ws6r7~~}&xG!0hT*v7{dvA1PS$e;hN;uIQV z+vOp#BNiW6&)c4`vzmM(o29djNm(q0SZc(g#vfQ0fx=XPi}d}jgoIF(2j+89W9_53 za+IcsUf2u2D=sMZF%fk|qPq6$QqZC-4S?UQxG?WErV$U7f}2no_%G$+I)=3k!VEn! 
zYSiKx9;6D6HR1#7)pr7t4&4iG?!t3CH{HsTxIJq$g5LVKBd|)f2 za97NpuE-M&hn$TZ6SzerngtzA6nAq{P%h|zF)h_MvOwB6S4x@o6jwT7(zj7uPVVCW zvvv7XF9oCRTYMpUO=v69qn8!e&2Ka`cWSd%54uu{T9}%&O}#-Qij+N?=VeoN0;$HG zYuU=!#+p}tQ@N@eAk58DwEg83FvX+;Lpx+9l2$o{Q^;}Yb6mL$rgIgGsTA{{9;5p7 zjumof^if){#1fdGqtJm7v9{uNcVF6(L zv*voj!Ndd*swn%NO+1dE$T{v0mxg<o*_a|%N?AZGg?Wo7KQHY2uEe^K!_C~ZUEP0ED z9zSG%jCY)OE&s_^kWx``mXrkALq|=DXlo@GKRf4kIHG15jci`EHQ)>^kacAUb2Q6} z?J*Xjr#bKM?3R_33hNgeX9|!~EC9e-SXgP;*r+yx%UQoCKmB9N!Urnk@MIWj@f-JX z==}VQkF%#4o29|(i>c7}TY$I=)t9~7t4(j;M201tLnz#N2mn_U^_L)GSSBUhy>rrSlrvB|&z1_v^C+6K8T(^rhjzYb zD4~vnkYxuSc4LzRES~+^Z63k z!cBRH*e%+hvgeVkA=*jm)3!KDsIyRes)dd*X3<+vk6EkCI2p)VTzr4~r;h%p1b6kh zB~W2qlgPIs`rV_(h)g6Uziq2Rm2T-J5U(R2GKM3GPxqii*5$l4LZzN)=$(vR{pXLy zjC*VI<8~@ebb z8{s)8ZM@SQSQWYguWoGYv|Ri3+y3-c0dMPDChJ;wiw_VL;&Xrca-a7?UwsA5#>Evd z)W~BxtcH~a_(h=KPMsi!+$gL79JjsxLvv>#($1EGXCm_Jg#XNK;%!y zYB3RdyvNuF`E@5SqSUCHV&QezN@=!ezRLNlsmXeLA@*aGsi9flD<>R<_@kw$BJPOB zV$Xx(QTAF+E}MaLfoBh5{zeBA^_$P4N>mh-hU_jx)m(-x#`6<+*AA_3*xK+dZ3JBQ zV-OF6>T1C-T@Hm@4%5De(Hnw-a9+JRgYBGdJfEtBx89|&BqKg%WA&TWwj=p0!SW7I zXpe~jMyiU?QyLXR4{&ugwx(H`c&-om?b1oO#0kS~+j84FnAD*47va9Y2>%(*=-Elt zONFY1!%Cgg{L!u|m0?oKZjKDzhU{4e5{yXaB5i>no`Ln7;8(8{u{`L0=VDY)E(u=- zN{GG^BCRN9uw+E9gR6|O(-Kk{Q}_*TMy!zt9jUpM89Q$xLKs3D9s*1{BrZ`k=x{r~Rxj{V6j zg-yWCP`o}HbnG)KZ|sK}OkmLm;)r95^ky&9ZgC+_-}~l!|55gNbD=`lbgJ~PwysB2 zi9V7N{$kCrb|f0Io>m(CJ6ycC=zhn_R7!aDMqM;b76ZpdhVfHHKC2Y}wY-C`D|JeO zqTO)X?H-Hbm*bJRu!$@0i}%RTX62WtvquWJz8;Ogb(gZ~)v-UxatdRNLoz6_ws~xB zMy$T;$`oRjRQDBaLw`r|KVlrw>ywHH5{qWW2Wt%i1#NfgJvO*~cZ|F(u&c+baRome zdlUP&ei|43_=Lhv7YyOWkxQx>^d(vT)qw}d4a`cJWY54@nwXfpzah#q&ZzF|=w>w> z^xJpVej%uN_HZc}4cJs^b`Foz6v30R9fW`T_6>aXx4|=2K~hrE?G{fjH-@@rvu>_D z1blbC;dJpQH*9euxIvy3#&lSIz45dxtGOzV+c>j~{Em*6$aT%_r{nk()_M0g>%N`z zm1i8B&VXyQH*`&}kA9dAU9W5ll++68-sb{n+x2G_*&R!B$tPVZX3>o;_8S-R?=Qb& zZeoT?!;2*2<*G8BPglbde}a|u?UVdZPYK~s&kUNkV)shJ<+B#MpU;4^rIQuqJw;t! 
z^in4CL!!|aGEzN{HJvSi`y+NP0dZ)5YHN#Y%9O5&x)g!HQlv){rA`g0S)PZ zl(yk?_ah>v#eB&Kfgpso4Zo)q?f-i99RtzyfsEr&(Y8=P+_ZBsm}iYKVi33-oB$=- zQI!~!gZhi)aKfboMVAltl_F{b5qbp?;rh4ie2HyFpZd}$5k)=wLshsuDL)paE}N8) zqNyWmHd#h5t?XyBpX8Jz4PS?LM-qUPg0Dk#j@~cALc^W&AgSnBh6KYqbei5z(WXRF z-C@k-K=yFjpGEB|B!X~$cUEZ1BVGq_Y%#I5o(jnlU(k(i3@D`M+)uLo>UmHox1EK{ zl;cH(=7q71V-DL>16elO2N=x9F{{b4*9X@V-{F_R(HFI}p=AQ0wxcAkb=rVOM!qYb zDl@n;hJu3@l#T%6deW!4X{6D7uY9D-!}DeuS69lz@x+5L6;35C_V&V$wTYexfQ7d9cHR}f&a#{loDYx(Mx!c7 zb=2!n1!sz5LTY@~#~G*k0qinrdEWKx;+Y>Um&i%PFgv#z+`&{@?1fhH6=~O3NdoP) zNFx3t8yk?EdNePuAkBD7v~bcDjs|4?C@qN+pYnM+8oK-#2gCR{E|@(!Rv}6ldLrG) z|2)eU@*s_MxsC>@sC4);VW8fO2L2AmkeZ!k^Eh4B@Y~ygLF0FSw!1es^DKM*XOQ~w zv3ftafXXjh{PxS7xY}@Z!w{-x zHlkD^i``@(oqe9VZnNVl(M!jBA@WL@b}5yXHb+bHM&s!wt&F%b1n)4zQ)Bv$ zFw0SAcc1vNe{7775J^#1cGl;-yTNBVE;I<{84Z{FIP=)Ce@r+eT%s9OK8%(5XaL>C zS;bk^#dNWR$Okz!klD@6ySLaZBQG$T~6fC#cL>N?W2=&|Hc5uW`5>}mG0T3ITL&huJgNhPOt#!5eP z($s_yHdR}VP>E&`EEx4)H`2gSAHhwcy+HH0!L5nv3OhSv!5tH!4q7hl$rv~Ex(ay+ z1I`{vV?;y%(W*1oQ2@fOrFQq?d|2K1X1pHBRxRJ!Dd_$Kazn*`HJ@xEVKmOHde(4) zl6jPAv?-opuop2hgz=l(QgEANPnViMS>mCT94oP<|0$0sMCh%Zq4VD}4v{^8qfG;} z4_BTpLB%B)oR4|gvdygtl=g&|aa-1X9cO<*rt^l~;ZU+-nxoX&+o6k@7a)ipPv{5= z>4BSlZDeLu?;!+TvP247rs78EZ!GEQ>dHD0!)kDvu}pcEJ6Qb^2mlwC(@;`s+aDEC zD^O8Vreb1CmzhP{IUoM~+g!bNdfXlF>gq})+j@k5G@W;Eeaigt%aumIcuclM|&pmsMY1=r7v?;PV4lC zOySriW2>_NZVRuct@x#qON+bSIX;xd!(ClXCnu||;#?(0fZM$%qc`}Uk@Y|X$FCoC zQ+UT$PkRM_V_{Yz$&lojmt{nF@170v* zvhe*XikM(%k{8Ilb!PcLuQIP2Hep)M(anq`+!@3@ zj2qDa<2+F8XkYy#+x~MNH__XL&-F(0eOBi~)MkziZa-%Pd8U%vbsD!|oMYH7rHej> zj32DDLyJ4l1oGk$Cu~5^R(!H?bR_$%bJ3HZ2R7o`Op12Kd18P~76Ukyjnkc-id++P zCS=8W?ldD$F?{QHr*?O9&#B?r_*=Ceg7>T-q2aD!PX*T4gdyn9?#v$8ITKYC)i>Q9 z46j-}X`JA=1uDx0orwIc^CW&jHdcFmS#D1)4bNrgWT{~h`bNI1BB^=WmU;J4PP&OX z={yM?I8diHIN!H%qCU>wR9f$0tmRYf(go@8cqT8fL7;7s+a z*{b+IEr8zeuh@S_SEjjUo0A@?J`_;IVUEmv9immL6|j^-IykI8+B&_u20wf1^8E-y z62a|?-_no*T$Xe{zrlAu+|Cl__^NtMSl(2=JNUnKWmi>sXjJ*XR9L?Kb9gh^dSG?+ zEW98NI6jw%qj|pzc$d}tP@3R+`k6>67yO9(No@%ml~C}BI^Y%!>6P&(_REYWk@#Ut z+Hc<%uIZwzVWRf;H_CNUk;~xs91qR6t5HkvYsaE*}n_S|lG94e)e`><%MXLE}*{tzsn2U2II$j%MrCCf+|TS*>@A$}~PX zHbMB!nIW}sY-|i#8@x4nch{_o9p8(N`w%oSJ-!iOq$xj{@#)-FFn)g`L$%}1c2EXa zh}65M=QDql&qHnWM9c(6SzPx6M3wpNtfI#EGG0~XU}u;}rm|A8`o)OH-aK-?rONgg zFPXL1RV|Li)#Fddmyn%-L(q9IgC3 z%*+7~7s?gZht!IMe$P3Y7YUaULeE&)y6r!v-T7$`ruc|@9zg?8g&vxnx2wpC_De7S zvU!lrz>Vv3iIP> z^N5EHE>+x}u6(mQBG9=R&~LL(LFPo3o6r5N6X-ZRajyY(JA|08 ziuS0Ca9?8E8hI>*0J@=<5%k7j0Zc^>$L$=i-9$;}N#DEva)Ks_BRQZ(MruyraX-Pa z0n=Y?A2d+oZc0?3qCJ^D!$?v7+JPvM@vKEQ@`R|>?aByEce>V3&VcELCs?+uR#PE9 zUURgUBfrUE;p-^lCX;!`CTU2d2kW0T*ojW4~u5gIwyPWc>q|avw}8 zKrDuMt<~=MYwq*11822f#Qu$ww}i&t{`<*zl7&34$KFZOMu6{h^P2gtdEVnwf6=eP zv=XD(tEhlya|`z&!o03Z*Ci4KS6o!e5m3%wsd7WHjA#@AN`!*kK_Xzo1 z^O_HUQ`?MQBkz}Wxx;|Dyqdt?keK*=4MbSnTKh284dq)C4)YA9G7Ma4DTSp)X^Cao zHjq)!8qi_V?SG4N^uy*TA(--ZHQ3gdS_7Zot+TzY0=A^VrQsUM1rrL9o@QipVv<10 zL#f@jt*GI8>7VL1OB;M9PD%>O;<7SqG$uMiwe*DSxiy_)fY|)JUOA6P!QE+hS2-IwCEW25?aXGiVGx1OMbU4IyBYG{LHUH*yzh-zbYL?EvZx|n1W&T%UE`7K*} z)q+W>E8=$;2Z#li*YGh#Gj%}MP=zH@3!K!@Q$}%)>vSo?luth2{tTPYE z7g4*E@~@}-Ua#C9M;#wLXAFiC5T~0d5=IDuMs&`Ij(tG5Ww={PxARvxDKXyBq>{{o z*G?`>HZe;wY7Qaiq0BFc+!tZ#$rp>wMjt(tD4xVCj%L(+>J>F9PO^);Q|GY=|0=oN zoD5EgFa;8)Gw<*2sVB$#J}n!%xw@}DR=_tnUn!X=jgnN^{nlw~be5YRNo=Jqx~F3% zj`G0-zZ^Ndzn{Nu$1ZMUrewJN+X`#?u@Mlk%Qss(mCFqt)<>=L*rWhw;V|DPM>Q|B!!og&6=fts4h@$Kf9IRcu0$}NOmQvBZE^c@)e z5{IIqXtg%UpLY(oWz)KX>65L~^K`Yed2)&91UhX$=k{N3o%@HJVQ+ZCPm#qy#NR3$ z%2!&w>~1O*y8T=<#NIk}=yGShP`^6Hp#WC?T^n@npiG#uIu@KxB49Z@8`C&1&R%?BW zq4@gn0ZsEPc9U_I5j7>$h#1yTh)j?iG6;=Ne?sn~V(ge(qQqQ)ENAl4|{ 
zM#$+{!kMeka1`Q^7jQXxH+6SuN^R;wEa-eN`BC|vSr;_FV&E5%r1$o^NHFGN==9on z(~pQhX=ae*rOq4I1+4gt?X~;~bq7A0`p6`si1x751}#rZSP^Ps$g$_Y)tmK+`89?S zKnaM^jtwms+bhzbuP3c$N4DQncyJ)_uoU21vH-oYGLgb~=050y_VH|J%{uLQ{Bmbj z(X5wvMjB#&JwI5X?dxeVeHz?|inun~&@_k51v{iM%xIG#B2xTSpOx1%RANL&$)T%~ ztjwia>aJAyTU~cr;ZXmbQ4%w~Z~n=J!LXG32o;r0ZRU%6gwmt5TnvX!aJy0D)C=9+ z?t8Z`+^qTkDrF9Tt~gzvuQdA%5B+AZ$x_w9e!WZB(?|RF+KadEf416ccN+a|rM`Y_ zIEDsaFEeLfA8BS`wccf0LyrLBtXY|rU8qDhNYCpA2RjW#^=~0Y%Fv%`4#i?YlvB}R zjMb7Gw8SMV6L7oSqS4(^P*OoTwD3wVJTUtn|Mn!yzNg?Fa2Qb|@UE~1>kvQKK;)$J zLq$d9GC#Nt*}|a1lXGGF#=(qPsKa$It1=T`r*s1M{n0^CO)Zm{{_ErY)vn>|#i4%I z{g!Ww@oq@><3=}|l~xqR)eTD$+grC`AQZOcTj2&wbk8UMu%vjI=-v5}}KyFH(v z&1FquFmY-4|2g>bgwX-Ai*ztIIPyEXGDQI!6L| z9jABC*kzWb?=Ks9uRbCo4Aj)vq)x5X-4^9yiMhq*7l79G;L!QY8rpPD1!eAt z;&gn`)Ee5F_weh_uDC{r$s2@4xtp~t;OUmRq^wxN-?{)#;RfgenpC|*3K)vSdmoTEv=Qv9P@7tE?uA5tC!9HAEI2F9JoYR$j zHZ+bC{YmiX9;3Kxv_e;SuqD(qREve!(-nJxJxc>oy_dIY9FJ(C5z(ik&oR`oKkEL) zu}xTnLgXkI!clUvb(7M$Z4M_m|L67euni5002X>OyGCeZHvMy133AT3s4EyD@8X!y zr?L`DR64@cR8a7-CAC6!a5qVxp_&q#&CWqf+8b*8`j83bS8VLmRhJLY#OZZr6Md~k zgsyaRBj>h*VuPjrb5B?N-kx#L=!$+s!DFG_L#RY8GoZWxc=h z`F^Fr^7;Mo?fnyA10KTOfV1ZK%=j}+Ev^@Y7hs>DqsWOZy=bTB_i!LtOiE3Cp35i1 z=kKfj?hkW!wtn}4|DpaBc91O14!#LAb#l4frX%^JtyThTT||5~rm5JBI`tm! zxh?DdbtX%ZqpHP?HtDYpH)r^zVA5r zKNDM9Yddms=uJ%1dtST^V^0{0E&7>yT)Z)yo;v5!t(G}YRULEL>-l^DnU@Ggo*FmY z6SA4y5Y}&g@crRR~<`L$)hu zww`KEpk&mK4*D_bjaO<)l3PFF&wa-BIG)As_!_Qb>d6nl!<|~T(UN1GPYHLwI$eza`2CIbyku;NC;v9TwtoC7b2k1c~EUU+qvA%6t z*RaF0;B7$p9MuD>D56ZlUNgeu8(5$}#07ov_c~+&!dd8yI~8eLVf}c}!}_pTv%30c z?xw;A+uZD5mic{U9r$u`e?z;&+0@5p|7|V8UdP$b|EmJx$>sl}?W=;~3c7Cb;3PN! z2KPYF!6Cuj9fC`6ceg=m72tMM#$;HjgW-*H!9O(P8-(Fog zP@1+&qi%X=+}PZdS=|}RF{V5`t^IHa4C= zZO0c!WuO1P9AIj3-mO9_Z%#m^`$nwdRa#ztZG0aqOtVx0p}`Iob-ynOR;IB;O$Ap} zR%(NLlQP!MnwpN>qr@?Y^_m?hrbY+_EV#OeIOjP@_*gbW* zhB&nC0`^j_^Oe{rT<3}(O~D^2F-NP|+GJORNx(w;ZlPVE;55>CYSVj=zJo1NfSCg( z>8K`#5KypK8j5!`$Myvv+3mBEdi)gq4!(Y}RjXdV>iN~Y3uCtSm38NFV<|;Zi-5A= z28(YXEZd50=rvKggJW49p@evbmRe{$tRkClMKjeTF#j-+fecE(aWnH3gZQIY-p6^^ zIwB;gtlNTBAZpVyfYxks2oTnruG5WY5J@?tGD^UQq9Vq8Tp$#u+3P3~*JIkf_p@qJ zKG@5bmFE${M=CVAAw~`n(D2Mb`pS~3#N^YhJxRzMeG8J+4mivU_J1O9+N;rbJM{$ zKtpCiMr3eJ?0%x~I7rzNYe@Fx)YiMbfXhQ+kV4jUG9~?fvuk~ELpMep>TRj}2G8u` zuIpK!n4X&oUfmLvfUly(?K9vC(_CEYoU_#8pWxej_h44bxi@H6XCyp)Y^QI>*6X}G zlmGPk3WX;kLSZRBybVII*3r-iH7nTS21fow;neZ9taWY59|qwG1Ye=SNf#o8NupGD zGxqgd`Eaj1bJ?>34@ys?i{HJTZ+#`b)Ow{w1|F~Bhd+0)GJgpj<9ooU`)x8!EQxZN1sElXN%zG~X;@9Z9H z9fiL?p>J&Ah=OpiHg*4nFs^w`=cO*m4LZ{FwmQE@W&1qRy$!xv`o7q$gp#S481LJy zab=v1veEFp+<&@7apdLUnaV1AUT3rQzPfnP^SV4Xuc!HQ(n{3ozCL9{l8G;nXQi*S zvkfcGS-B2$u;P}>n@?1aBC$Wa8l~}l=}XYWVxS+tQa2}tolcT%2Uh%xZ2tou-Slyu zoo5*}Fa4FNAz?juzlD?7AFq@bEi$mMCjYkg%&Di7m7RSn?)$z^;7D_LU{+i}>D=JD zH9yt2)6o!lln?B6U+v^%>~PjqG8EF*9 z*ZQVnU&nrNxn8I0y`!5&cU_$>RsCWQNuuC>9QXZbHNX+x&pwq7NvP4fT#D+cuCbK~ zUal?vQV)~EuW3)8kv)I)zF-MHAMU0l0p3k_KPUbi9CW~a1aKuJCUTQHXtTsEEljUA zKjRnQ>MQ2K^T2QT6Q#xFRle(baN7GgUGR0^7W`5D!Oq@J zI-Boat)w{9SzPkbTb@AN|%3$zr!ODwo{ry9oWvY1Q_$YW2#v|{-{BszB4*^7ICNgz$$F#7FSc*>Bdh#)mt-m9E!BW8QA*Sf0gaD6FJjbU#*@X4Rtqg zu;MtKsZ5z)Gb2Flp9z8BckJ0?N4taSrQ2~g1MjsnA(fK!=~DtLipFHr>^vsyTV|~S@o}xyFE=<85e|60XxeSU z-5V`;3gmFK6Js&hO0GjOGju3){4nxX;hRAA8!Qb7k%W-!O$l&{lr=ShH?`3D^mI$8 zZr!r)FVEO=E~fJo22bz;|7v*t4BL*r)#{I0C18&%O~YqQ(PPdK$w64bw>5Bs2?daa zSo}I&d+}fIM90vTN^*(Efu0OowxUEAXr`;&oL3UK=c-$Mnp(Z|zn>-$%#N-GqlE=k zvX3wiRvW{UNU`+RSHWRQwX_4?mVX(p+CGJnJw@j+GFsD_k81FyGvoNG0FJqhm|?In z=*!Z0Y!p%-@6lba$ET(RYHqi>GC$Q|xc<4=A^aZTbIfp=RbA@2b?dsV7AF68bX-_n zrQj^&pPb1C8nQeF7-dr;d_F-Kat-)R8j(71l5>*h&=p1VYL+`3KO(y@Z*ZNC= 
zTKDsZq=TllEcGM?!x~avgPo(bzO|95Ajyk=4KBk6_q_J#rdC6Fo0>E-05;PdT6x%o zbM9>9O6#L`vj>Od`a@0D?IL1eIsUSTAHBiQH%eALL@Kx=u@1y+YAVDekMi6ZJ zmNMf|iuWe?ystdnb@va-{!t<$>T7s*Pq>Q7Z;BuiPl|ESmvwH6szW_*$Ki|hgb3HI zP483rj^c{}U0VBFG!)!lg5c{-0q_>TG`+a$0QpAu1YKrV(;97;K7}9Q<2n*u=Cv3w zovrKJsHt&`Toy9Wn8&K{Cy366is4Bi94X70Dn_AiUD79C8OF+mq=~xJipQE&=-Gr! z?!PK=DrdUD_@cbaYj7lQP1LSG>EdrSClmh;DYcAz3@hayr!#g?g8lIN2Our-&;YFf>*MVv8731Z%I-coBr-xZ;1A{}hi@Y_Mdud_Kbs$F8=7;;-%wp?ZA!m{R_+P?` zvW>4D(nlymcgp!{>*FE##b_-&7A2F%rA<-?vU#>_9KIJ^eP%c&)?H=}fXgBW2 zlJr8oz{@+n|Gpe(MBsG75s$k8+^oSxqUYdcC40(?41s{jnQ?Jm7k9049!h_rjAyJw zT=Mf~VfdWO^!CbJqGGnq18veAKN2@Rm&9xAfD)>7mM_#DvzRi4t0pucd&M?@W)ccDrKmb0Naj!6WJ6Q8!rxD_b4% zbHOoxQVt=xr;O}4DeQPP^;LTh( zw=pVHs+kcnxncQW5d!+Ku7CKOMvsQoGdYjRbw>9hKp?xbvs-c2HnNFs#DLvo_`R82 zW+s=3)$Hh)b;&)BT}^@ulwPG0#9y(jKHP zgQD&mMiO){ z9tFGJi9<$6$+Ri1hG+BBF~Bh>g7=Ib=N+k#E~;D&sqQ%PdvdM@HqVy^&*Zh28%taM zo8xw*Fmf@VNGp+JUyMbfHKj1J5^Id?`j^1Cl0fHZC9O|Ks>{3S42^UX=%qGV+Soqy zW(kK1Fns%%^@=GorB0L~$hrhllvX0wb2E6miGY=ALK`_^+gnAD6+SJ_ zLaPKzOZd&uNi#qZz(lhqOkfYn8`@DMihQLn{DaTc8<`CEskG2c^mhkIl#^!gwbgPS zq^DBoRLlod0>2+x9ruSWGW^sN@e!!8AzglsraLc-f@QGEVA4w50m zY1<*-JUdB}|bvX~*$e z?1sqAV(GONG)_`7chbP)e8M%v@AH3o*Tpl5C6lCu5a=b=2_kJPVPXpRS+cs5;pnQ}R>jamV9XLLY3)y!|$}R|pAz#ptmjGq;Dz}y|qV+KOWd855 zKSJRZipCh2TF!f^;>~2K<)tjNTy&%NDWB{IFBl)!5%w8$^`gf8w|lFXzr|sC0zCLh zPHKVhk+ggjn$egvU}O4+WmL(*Hz)=lZc6%xn%o3t3uYb@?$V!RDxjbS)k}m^X|DC! zl6`BGfWao@OjukbPm)+StMOy3!>srH%c-nNVl4qm&hM5#&R95Fq<6Ku`xF1+0iOTC z!ozd7c%6V&52|&$Y>vw>H(#&LpW$2U1ZMEC5BQ#rez}c3I7W9noTDx2KVx!=sEbTy za5@j2NsjC3J-+@7*8AD3W2YgkRGj{{!(F9pLQ?zAb#J zb6jAU?0)S2fOQQU>3Kh zx$Re)2ztM9)|<(_AE=XzPaoG(eEGURJEv67T8Dss_JJeB?Hkv%?Hq#UtdgwQZQVpX z;m*z7vROE`;TY+C8Qt{~%n3m=zk?rkQ|I!hKfwldr|>*^vJ!}`h1TVw{ayYtjVB3Z z3RF6l%i>S{T`-%OF%H`ML{`{CH%bMvrhpYMBr8oJfg1cCk}w4TLi%Cy5%QIj;=f`G znMg=zJ0rZp_!+C#@lZL0A_d}Sl>2$FX!wz*zlB5O9v};O?RwmrBB7vDT-MVpRi!nx z|1m#z*q%ewu}NfqC{eE$M*oj%mg%3k%y9IZu*KDsok_9rmUJ(F^fTFVLPl4P&;$=z z(}fZIy<_3@O+V$Ll0_&P=fz67IxrE&5dxZ#<66n8@i$XztQ<-hk`>Ld2oqy0Aj2Yx^jc^@X}~{3 zJ8SJTIPR7{lZf8B%~^_}9a+qj{{+e25oFTpfxP-%(F~h zt^2P0@Zo&f^N`|k>KFTW#<_O?U(rkZJ}nsN0OUyURodh=QZ!r|R z5iyaDRn(}nsW1&Lii();M9tPJShLs=%_EdQ^4eq#VLI!ssPhUfI}itIa)C!jqekl?Y`NwnwWxOwRIfUTq_G z>o7GFHFOU7Pvxy+k5JHpCnj64E#`40QD(M<1*TWf0*UZjN&u8c0JB%(7$ zW#N58ud(fZ5SW^$pqrmO1Ru?5M9rrLy1Bj1;vSMYShqWU*6Q>789>!~7g84(1J|2# zRK@@4F$rhwgSHm$Mr28fMJR%hfR^Q@M%drSCV2h;b*FjLn34607f4l6;t~(E1sEY?* z8^p{AXcq*}|HuW4HPI3T>`9f?xTT(2vbCx()Jt%Sn&4sXLN2}s5*{S=Yy#BNO(d;r(tviO0rlL&>t!3;Z`w3(XTAm0%vg zy;I~9Pc+2RU0^T$0_7@^@e!g5lmjqC95Mek+^Q1Ga%L7iiO5QW-1>*UW%K8VM?k|5 zz6?X`9Q}a5U)>Ah=`ri?R;@$>(QrdHXVB%P?J|}8HYl)IC)Sgr4hQ@|!}8hwUYBmm?&YIUYyiD_{o3~QT;)W!0RZ=?VVzW|RZoQjr zb?c8O4|N57n_TQP@2=_ce%!4hIh`EJNu@O8t^|hoDkjN~Z<|{a*yACWu-Sar^gwwoaW&^|CbZgUj207zMEt%&vzT3umSFE zx0RPs@nw^jk$4-U$acR@Y;LmJ+%NreHeaB2GT)9U0Q==1ZR(}Xt_|+YmZvmrFly3k zfp(-!n0%sVJDaZ_0rcw3IQ%>Tr=Yq>E-5AS{b%bRa*L#BQ-$F~hJv0}VplF!sP<1d zjdA|>aOVDQT%tEMm5>^8sG7_PAF%3|=+1-A&FGuIzWp?aUlENLPbtRR^QryikG2o_ zD%Unm${V17MoHz7(0tdf?fJgsA}qgz1%9H_YHo5X>MNBbO>ut@LyL-Bf|X6W6kI8c zv6_#N%g>9eF! zl#rK)*Y~5dIL=6#kw#ucuWi27V&$LmusjaNm|2IF!Og_A+si|JbobpGR<&9b$*5(l z68a+1H3{HtU16!^s>wZnNd4_(GuwT8bm2Lrr+P4OuToj{BA5~Qhte;$VU)yPqm#q^0dyNiQUu)MxAKR4H8 z<4=9peWI^Mqt|+`cgSso1145GxWWCM(GfCQM9evd>ncA-4GlB$WTJP|?@pWQa=jmE zGyVDFkb%mhnF@wnEG@k%Ki&3R#KS6AUAMo@uX{fD9zNNhRp?;fK2{1o2gc&`Z$S%! 
zV!Od9%?7SoL3zW_p#+{jU5L$tcN*j-6lKOpR0jJe&h3u?QW{Ls>nvWIUWbA_oc{ES;vi~Q#-d5!019hptF}>_6C6RP?N@_psX!%mC8@??d)RcBQ zd@#_5jppi^rwYxQ!%7QYBg&iFI2byd!4R`Zzu_j?GJJ7O?YcCicCEdORQ%r)Oavok zX0!fjA=Vk8%{RD4=tCL4@8WaE;}trdUVEoZbaja>uIpjMnf8z6ww_Zg{3J>hl%eok zd;vg>u-Wo2;(&We7uTOueWnE5*9_Y$bB8Aur)ghL)H|Ba0$tUlmP?R~9Bw)|@QxFj z9tlnM@NCX=>oD)HR~j1|s|D`!`1a3xzx!4-RH&%OB*ZyE_9q*i#*ee(ljS8QV&_5r zC}PRQBwp2QD_6-6$$t%>F?D_#*A8ftc6qf{cwNa#R}H`H{xt}UA{KBz%f}+=Xz!}< zl5p_>UR&!l?>@yUx4Lty&J7h77uvugTDXk*9n1i{;DzlDyE zqH%a&%BE=sFQ}d+ajo3yeH^@&CsnpTR}nhi6wwqvy=n5Y80b4F!NBcy-&LUU^|jxC z>mC96YPw6U%*_3F8|K27osy)KKZHsib#+Wu=(_AL;?wsIHrCSD-rOC)d|~mP@08D( zHrh!h;!WWu<*rDQq(3A`$>3nfJc#i@$HCzDskD02%*OHNX>PW%%@2)A4f&uZqFHqZrNam{wgX&UB2YdBJS(>?LTHgFh={04v}Z zTX1XXbL72*4!bRoDYW|_TzWXtCWXW3By_!a@oAK%%Ju_&ZJ-Oot;~TkFJjAsOvY4YHrAW z;Q*k!PjY+UqwNpA9I${@FM03AeS)3LY7n-Q9Z_3d$%I)#Q=Cp=Y|-`GD9w+X)liLl zqV`I{-brwbkzEb!64@#%L|@*6xnd2@$mky+b^@^S!OH{?3JE=BugDcU~hj8JipJ;rMF@)Rj?IE(;L;ErVIk_qFQJ*Qg(l?1*HeKPduv0uh>uIfg;6)9^)sU(wtQXvV-$R-f`y1oQrC2Ts> zz+C(ix_h>i*or^G1N}Z?ZKzoCIR1|@ZbGII)JAy4O-4&e)TgfUjb4NTs^CVsid!JE zVW$#@U+#=-qFcY$_xVKhYduVakwIfid}O-V@`nTOxE^Nh%B zo^4L>A*aprEHT>d*FCI$XbHZhFe$^v+8_t$V<=>Cvz|%KEr-ASU=aS_! z`Ml37lQ*vm3GHu44+CDWL$VN_o@sQuP7XLah)XTxvRJM`yj3s7w%*`eU-z@++9CsQ z8GX+1aAa{xfL6e1BU0^C-m`62-En0yr&rtS4^?Z?F>36ARa^q*rH!_Ob+lEVm#cd4 z)#T-$bS{sb`JugcfbuC1h>3;<55pVu@8XF3dLFw$CK@^xeW}C9)V*zGOC-w`{R_q` z!7-~0pDNsj@D5UHaLQlQL60h(`mw-X{OJ@Ryf30~IW)M>_fU4J4f7RGmQ|)0D&OQh zfm;A(#n*~5^=`VYtjyD<^QVWoVvL_AThSIqp524=SAmMIBC)~|r(|o@DSh7rgB8LK zM59tes$k{ZB$&VqQ${re=nBIMISk#!n?0$NI!ICREA2x0{HLno{=ZgLOb(o4?RKxW zun*NoNg^|wS<8opG2-Q|#wp%sfeHX(`RJX((Gdj7@o3B~w}=tL9=OjIbD+<>Z!BXch@wSOE(w8%NC5 z?))ZW&akCbnZp!wzoqkumo-=@CB|fFu~4U<#3KC=m4sg%s2To>Mnpgz>qhmR&_`0b zaUM8wgp@D-C;EY4;RE44obnL9{CfhQv3wef`Ocjr{dkbc?zVkR6ei1~tJzbHC_@rz z_p&ZyH?!??T7=3s8rOlHh-6wC&IKAcM`Rg{L+W8o-;t}k07o|Ff;cjIGIMjKf)Muf ztcKUP8N;usoHsA4)gItXVtIj^H^GkuSR+EhqN%h7-^=aF^ycPFYi-B-!Iqn^gb7L$ z&dt^EZmUVrE|*Vm@?8LkA&Y-WxRKda()6#_6Gz*Q`tA3wtfN6xPr;37LWLcrqYvUC z$HRZCJlQ^X!3WBWsvL=h`&3Rp@Af$>hN^ZJWmmi3CojF81=j*@fAyG8SuZc$nN?oO`x; zwuy5B_YMw5baPP;i!X`cXWnJZ@29Om2mK)2{6_WuILSCQGW5n zK8F|Ubd8i5kRoFM)mJDc?Q!P)ajB3vkji=8r(&C7Z;wGOy2cVERV}5W(NLng`>wXn z9)T9l9G}F;hn__r$iXEQn5Qc1CyT)rqXo4r3@XM$PE}K;C&F?Xfd>^q5Mxgcs7M43sD&$!1ZgYH^fGNyt~? z$IIT2z~$r>1Xvp`8~s6uLXDT`h)YK6E)Zo@paVL<6eU2J@=y`3(2v6z?~B5^O#Sang4iT49Dt!XST993EIMtXAg;2Tirj?L`|)$+U0Sm_`90=FV! 
zvcCE+O60JpZNi%Nft%B4s)6Kmnn2OIQh$D{0?L9a3!HFkrOjUxtiR-VuIlIv#?TMd z!+wKGSnEN59WQibf9>P>>`t9uo_m?~FOqzm?&^WxUpL&$`)etN&8D+~CW|6#YD%iv ztVVrzEAbc?yaDNQOd70&)Bn_R zo6g`_iR7QF_S)SjJCZ~6u>GX2z4>_G4!-OD$R75&yUPfMK>z1AN7uZI&fHeZZR#5r zZ|~rEUPBRaGj%iV*A01efPkA-?<_=FW+KO0P3;5xvE4j2j%5}W4*L<^{oKxxU)BgS zWiP{7N#b9Z-l4LruQP(2aJtf1qZx+MgHs6_Q{uTsFpi46lUUsoPauHgOZa@WrD26v zp;Z?)_Xsa2^<+$Xb(*cJtffe)Dz70T46cg7D=Djv9!Lpr0eVOTMrh*3<$uv-$jDzXrkA0~Q{HgDj_T4gv?$gzO5jC197L=4_ z)i)8TlXFm>Q67t_;?OaEqZ55P`Ng&D@kEw++6IbXF?SZPElfN;Nf@7QMBKWJrM#v{ z1+u{g>%7f=zt=!JO=K11#z5p{mku&-K5E9UZb?Al!jYj#^)zc_}mcS)NRSJgK4_5vT|GO`>|vT zJ{5d^rSZ9AKZE>O&5a(kU$mD3+YU2v6ZJIgMse`A@A~nwyzaAJZOh84zxrb#T{E8@ z@ti`lU!b@zpq1-+-88CayO-xvw0Rrf<-V@9Mm;>po;EdPa7V13Z*N3(c6!%{M-w_< zs@uM=E!ohOpL8?xwS{cSCgb4l2|gd85bAEeZ46-`MMw58 zSd0%@3qF5LcR|+KZA~7w1Nr_`hY|rpFK)|`B@hFfDB|24qKKI}cc_YRHSF?t^Gm@atYH|;KuXLq zC&To}-!o4rQ||q+%BrhcS*=d8xILceLK8njHCYf{pnl+nrmBfTnt9kB3o6kkC46*h zpJ4I$)3&yWl!mta<=HV+^3<_?btvCx$}p1AhExJR8`HK5hWiu=H$Fb(>XMf5oAPh+ z@dSJ+R|*xV^mr(23`0dvynV-3DO7dWIc?c1SfaGR^=ig(r}zwnx%^aR#L>wg!Yt)Q z5kddCxq7SBTQ??o{qdGrw&hCHgcuxdd>ZA#L<$k#2%2WaNB9TrXud9Osec&#_*?-1 z`bVac#Nt-zX`h>tO5$|RXs1zQttfYJjQ_f#W)?jY)fq+sc?V2+5nW(iLiXzj2u1-* z@%S;2YStp8pgep4Yjv1*d@Ws)(vh-}LCrsv*QlKudXXTk^3O+UmqK)@p6Mll%mj6h zmf}eQRJovwZd{>qS0|k5=*jUv_P)9pRo}LTp1@CFOh?q>G(Y270>`4a8NGKVF*uBJ z$(I*jUHjcK$+@qDL-VR?G+B|UDK~4e-0GVhMOC>6R???;+{zHz-^KHotYh*NT`4WZ zuR+Iiy$m$d`0G4V!vB!d=TW-Z_u`*B-o#+PF=ba-)d?Z^{6qVq3nBV2tNu6j1J(>f z2288PxoEK*M{{GNYm>`P!0Bu_=9Nq;i~anvhkLXH9|AHSXgWjCdGjK$SU$U#SgY0+ z(iN|wsS2v6kRFdvJZ%PFeH87=(lW}M4iGb5z7BOVsOkM(SWx9`Lr+hv$D4Yp?|#2r zmZ&T{?}i3DH&>4y1X5+{JC!(r3p(`PkIyFd-OSZFts<;Bt_lzf3Cvh-H z&6kvh&aABG+mM;Qf~e+5nINp`6=_7ud&Y{0B>ufn*xj5%-q;O$sW{ZgwjrP8)x5I- zAKe5vm0*omjZiyH8IfR_IB~ypN2%|BMwZ4(EB)LWZAA;k@g>nmP=mbEMWU@tg0I&5 zj@fZ!%xWRkKfb)hfvSgxqDtgJ85a%yez1W6Rai|Vz6VpkQ5Dyy;^T*0j|6iqys7;$ zDry-misrppj|?%Hsh}Bf1~n*Cz8A5enfN%H?5l#7WX--|T@*)U-EpPybw{NO}}Pl#?1=;@jX*%mX&L?Nbmv4ZguRHY^Znbnqd#){JdswQDoI5iDYpkT-4Vpqz9n*4f6 zTH}%~W^iO`I*%pm$5F+0KZ7MbQ@0P!v)NR78T9W6BmhSfx-Wjc*S%p=B8%^_)m!XT z>TS`syesF$-{rZvg~0Tf9Zg9$kW9Cd%9pe+Y23CvA6pmkndg~qQ!=_6BeBmEERpPS zc-5^xcH9-^(trTp7ABy4b=PxUNh)PFi|f}9k*9)?-dR6C1{)1VewTCvnyAUF^pi=` z4dqk>Ep|27^B{Pd@6rST@mE%MDju{;V-s%pi{~BluINaY%|_5kME#=3qvuO>u|4^1 zs!fWyo7;Wbk`;1zx1XQi^S9|OrdkrkF-&*BmgbRenC!h1~F1r@Zeok zRhSo{D3anGVVzlZ2@S#s8cF9*!GB6B zUSJ1Pm>CG8i~AhEFYEi=!JRZD73mal{L%1BHT0l7rYW!ZUs?cRstCmrlo`KYt}`S+ zD7I9!JUH75oNRMH@+ln)C(|Z-MMgvr>R7CY5Tw!@dx|p3j_E)BzXLl=Uy1g! 
z?+ZWRJ(7$zZS7Pfw`8>Xs1`R055+LkY6P{X!Y9mtmuQsExCWp7ry-+-^H}(7(p%gb zWp2p?6?9y$`P|>}FEq6Bn1p~x0@hQO=#0x45q&Q8zfm$BARHvTRCKuvj#TZmr%5?0 zRHJ_VuMkN`5yg{nC>!Pq5ysOxb!FQrUiSpVL{Y_ZEV6hkxz>i3|9T_|XISvLl>oCB zT-Bz$yCzL9kEb~qj{zL!P=Osho-v8_I@;rL+6M?y75ho@Cm7LJT>?{ zM{>Go&5~rf1;V8gDO8mN-QP5^RaG(!9mtItf6B>AP&tG}qFG*6Sg92xjA>TclR_y~ zsD##K5Bu}WWF>tm8Tc1*CA3H%r~DA`&RpTcjb1}wd$N7^8MQE8MXrFFtv;LI$}ZHh zwih2iTmdArI;P(zgAwr)k#aFlt+;~LT2m-4of12-mGTE`iWP$n|swqsKe zrt4h7ms)(1?l5K;Q4bdoclBB1CU4sWI|~qC?1S~KT$!D*MzVwt01V!vR+_M9=8v47 zWfu2C;wRIu;F&+!gCEZ^lt5P9WkBIB+c39jOjFbSE5{@e)Xl@f8vC8;sj@aCjg!bF0kOtn`(ugZ zEK(|=sM9%LCh`YxZl3rlj+N?Ok8+!;n4&NP%i;QGIpiAHjr89%k8$C%n(VcB|23bl z#tX9U1AI#c^S_J}g<4ryj}j)OFUN-+N1V}($U;NR?cn}RSk>8wmdeOTAmYb|WBv^N zY!$z9_c`(^AB@M5|#KM0}ydf=HjVhJwYr@o(nk;Z-6} zVVVb~`sEIBy{cM2R;sapII}WqGx?VtWVKyh~==*V7f; zH9mZJ%$V1UN!tIwUS#DG1zPW-@k!zp{?3Jda1^|=ruZX7z`7Q-hFM_c`%TTjj{C1c z@>XPum95W&S)Yvx$d;I2=cu+t%$e&)hscJ~Q(`_dQ{zwtd?s6_u0JNj-=-r5d!!Hu zgDCgfkOQgvbtYJYduPr6k*Tnb7ZmEI`=nA?tIi}M3bKr^p<1cRv9?G_%Ty3oDbrow z$&uG%OWuT1veBzoVTUQ^2^qT~{Lu$cRmBWn5=eGq{`ZK3`Xrivf6Ef5z#UyTpd=(R z?8h>UWS_suQD~!#i?KEZu5bD;wd?trmH3ss-QScS+>S3@Mjl=}$s2fai zvH_qEU5rWdGXUzy^65xGekRNt2EH^i9-ipNrmvf<@9c_Be?QVUw2@U=|5DSVT^f3<(#>*R_KmSttELBDQ0R?O2ZW-Lt&~ZKc(X4GX5G< z?wg&7u>An=hVuopRQyDZn@c?8j&hK)Kp~6>(QKypGAGA^`KpB&r(XcZ5VsM&d3X8TYuqWBG6?s6~f5wtB0v)N{TYwC)qODVk&PKatr=bQThr1pq5_Xq1>g`gE*VXKb+o_nfx`}?Re|X5Bs=$sStiP7wwUdcfJXLf$8f#Ka48`Md4wHRqCpJ$4Z`(OJGDCLZgfm_8>M6PCv$zC)PCsNfm3$4PMIhka# ztVQVzHEpZ0bIe zq=r^gW4f?nsH(r0>3K_BM_>If;xp`Yh#reWa+W(t6)y>x@GxE7*bS>-@cb#)&UoVl z>9$Cf)8!GJPnIiHb~E>rtl-QdZsH=2I<>Tj;AB?{AdKW2p(-;M6Hpn}U$%QxKgwMK z_crC26;8WOvYkXMQ;85@M~{(PDOZ(TxJ`djocJy-qD>0y%`wqAz%Ib+RAE|zOlJd9{UyE zC+)}PL6l_~6#SghYd^t5)s>K3M$=nExi+|PX8%_IxShPExVm(JlcK+g7`x*zo>#u@ ziM=HIm|@;DkO5)@rcUWA2aZ}9#$4>M&$TYPQ+UB;tj31-aiC!irQ~K2GR@EbYuoSX zH7ZUgq%RM~JoSOj213oE?QxWzS1hBBG3R7rwW+V(z7tp7M^&>t5641#D_ys;EADR!Xjtm*CL z=5v0PVTXM`2*v!~Y>}=kQL{=_%yg$x1)XLQirk{Gw1`5ry;21k&h)gbo=3IH!Zaoi zzs}0z1?FI4D2B4NO$?`hv=p!VXRQ?L7 zR>M2*>}EI~ID&vL337A&fuB41>X7oG08^AE{{-JNK1#~*!;%#jz7QbSdbWG6%CkE+ zH@P-Dw^Y%8z3Y!b6y7t5thGSrGgVb&NbDopn0xDc6l6UlcS*wWJ6PKL$sMc$gV1bV zxj%!eu`VUuY93n*k3LUr0La8l`tweZ5OdQv9!zTX%~Hedxy5ie%2?{xt_%kO72z-* zaY`z27Ip4;f3S|de(8LVO*ucaPdu;|GX>5OACgR$(Vj!bgR>w*D<{GT}QstFS14vXaNq~<1M3}kGm zN!JtNB@gTVm3oNdzloIF;kGCka++P?fq!foKH_-wl*n?}{z8RSTsJclaB! 
zVn&4l)4$z>YOS(#EJse9R1LaMY)O&y^>cLCM8#S^qhitJkE!*kB%~qdKedj65`!|s z|GXc61ptCb5rWhtKScw$++U2__NR6Q1HT+y9CIg&E90rw%js#c@omj;&VK>RZ*%Vw!qt^r!|BYR}gX6$^q8bJ9f4rAxGvKqa6&KNjm zW`(L$@_ZdcP*}nYbbEDmd!e_8uFYhCd9W&-*%5~RS<*|@m@M$jJiC(fYumthTX{U{ zM}Th{g{g-`-i_mB3#_Jdef|E~O!KC#ZGHAi2+om~l@0*pyseoDm(*%CE>xw1F#PqT zC1Oz8Pl;TZslBlDwkKpV5v|Z9@tH{`G}~;6r__QjB`CpWlB4(71C+qW$nv2Gqk8ru zb)O4<2$3!0idk%^QVt2R7FH6fKjL6XCRR}s+eaN@DK=okzv_H)wG#f4SM{*5Q}==D z+~5$8Cg=zZDKmb!x_V6!kKs{G zoQ=2Qb&JDzryMIonmz6X59m-as2CQbMk-d$JY+6S0uSR^7Hu0f$QU4}k{PgB=ucpb z+*3E&U`}8HPe%gp*8A6C$F$pT4?J3+N^VqB!IaGTteK0Nn`B|Cv!2MM@m7adrR5_S z4W+My6HSPE{8(?FL{gNOYC?x;;{jK!I& zi9-MG$a@(~L;PDxmzo<5GP^NB_k>*Bxq^S953(g~$-Xg4W^G(2i6pI0Fpi8fPLM2G z_*!{y1q{Hjsi)~UQ%6Up8QF^Ul4h2fvl?$1#Hxg*i{NOpKmflZy2M*1K=w~_x$VMn z0Dgxl*I4~DSr`d9;lihKB4n z*Sb7SRa1R=bv23$(F5H84)2{0AwgWIbE?0SRv{PKkeEqUwqom^lfc_o3!@vgnhD)U zB8v;0QT6yzTbqQVcCcRv?uw&Su$161N(t8_i%+vM-HL4qm&TkVhl<6anKd--c;hfB zznpopuv8Q*dClw-XL3f3?JI>fvp#MKvzzc7Kju?J z#8L+@C8CyS1hdvnU=kB6(YVL%92!i3Qj45S|GsUCPH*ST(`O=zYoKr*#(%JJ2t|W2 zw)tbIp%CvmV-k)+U9w~v;51lj0#FGX;2WR>sx`yR*b<9<+jk!R?@w?l_|BXn50Z`H z|Cs8h^s^w7I_%@tq8>;=cS4<4DwyWiXCPlmD5||ikO8p!TYQ+^DE0307Q+K`w$Pb@ zv?jtWsg#03Xx8suQ^0GIBk6*rb?ZMin4j@LHlA`6|Kk4F+<$YE^#gGx<=|=$9fc^ZZ z`m`xxH(V&0NO9Mp%&bN1UUt&2P51FjL00~SMv(oDv5CBmypsndWC+!Od?uZ*q5F4` zr>fdku?ODw{LfUNJY`V)p)4w%7{o&Afg2jrkO*SCA( z5AAtr?+@odu3ny7BtKbwVZlQ@O)RPDk$+2^Dyx#HQ_T|XFlb7l5YJ>xI{88^1NF{*sSv}XNw8a zzwa9%tNR=@N3JNtcAC=Dmt{y{q_SXtTc;4-?}jk(ajf5N_5WTs_(LrTRNMSDNlYH# z!TjZ<+C*|HPU`QJE}i8o`fn%?H|re&zrj{g`TL38bJ)GtZoEY1`}{}N^(eYu8$PCJ zp4y^TNJ`iM3wovaR5uk~9H=^o*3m>E%wQE>9>mUL@(+zxJ;qHV6# zqe^4;oQJ9u+*OA>E@P=03Zw6qnf3#YB0o7PxwLHtqLA>9_cN-aNC|QuAXutMnuVos zs6hBe-C7>LH0p;zDeDMTSeP2J5iBxO2CEtl)bCgaW?w8p31L0fADbs+i5a$#^TXKJ zUb`aWVG~C(B>f^*XcRQk*2z}~GK_-QkEz~?;Bbv#{kGgKGfmv!Vrl-`z_}3+c~dZ6 z9_oMZHDRdkW%SBb)vv{T0vQ7aHHNLuL_ct@c4p@9>5J}ZOnP810}Q(ia*Q`CNry~?Zx-^)5q1o1f|HdfXX8@{GkPsBnJIBmDEB`1}#DN!5>ps zJ~1eL>dT7%J7-J2;Mc_45No^&AFRjdSL`_Q^_B*@gk8$mb>vPW5}1BEc;V?D#z-Xz z|E(d5D86d-{7Bf__rApK%kmG7v+3C6FnD{W@5m)46+5r3av{M0^75}|djX;RGj`A_ zg>rF!oZZb8DZUFlEbPd|!b95Gr&r|z&~jx_QlCBc?|Tat%l4JS?_I`P^n<$L(tA?Z0( zX+h%+Xw{Zpp&`%$DlsjM69}bap1zawxioq3Gz}$cb%C&8zc&@#qRd zE75SFBPGJJ0{Kuk>Yg4md{~!J3V6;}9=#7s%T>E6W+WpEVg(OgY`YlMnMB8jJ4t&Q{QbW0 zW&`3y^(~e2Bh#p;j*F<2?eo}z;;8ksIN9Ip04i=(QE2F?+_0ZurLxtR83il9FXvaJ z_T&LD&zeebXpv|XG`?@2M09zyDl8()r402W5`R#It<=87^({Y*`EZ=8LP^G)wqw@l zB3}@~RT4vn?9QGDvqu+Geqs2zihdbk@nJYQs2OEW3IAp_(hLE2qd=HiGQ+)GHriDG zUwg4|3IINjnkBaSN;u7=fQsfZ<^QvsX}h3fR>YGgWV`*D?=6W>dBYCVl7eN#NYyLE z2jibARJQU3JVx!wV}db#wfDEb;xQgvwZ>c9S;tsugm&HcR(l;M{!`Al9*-p;`Q6g_ zIx1}LSolASWQnMAQx5L`+1ra{0)5Y9uIX_C^mjgzemvV|5qeXx-dWB|&U|0mW+)?z zAr-M4NKi8AIX;O9p$w3Bl~a?tEvd6k9{Pxl{Z(0Y(Z$s{i`%`ustVl$`Q|2EY77}U zToiRN=2!6B!UCI2QTlg!vyYtvk8-wVYC|Q3SnKQjR8#{=Nev8`Z?i>m2M3Ymm6a4` za<-2pg}X!1B*(umzkg>aE-kO9te6^`*iCdi^w83p%I3B&-?zGLar8DZQMq@{Lq>_9 zC{?22YvaA9vfLYva|5dwC`bqjC9b$uUP31$8yg$L!|6iBQl{8*EM%N`%=$xjat)m< zjB_Q;%sgjH#Oc&1;pz)a3c9m+ype4=_meo>RDT4K<~20zKV2J{X=OM%I{s~`Ev2^I z>y9c~UDcA~$~if{prxn(!Bk~xI>+nWT0+Xym0KH}#r!XYttjKFzVBxABopI+|I2I4 z@MPr!^y;FdyLnfpvUwZ}%gYUob%`BTd3iAdy)yrDd^?dErCZU|;^pNX9iwOYtX-d% zhpj9kB7!HE%@HnY1U-5K5qblO8_jn{28OcY-cS)8zV*4;RV@WwT}@5ePpeyqWRiyc znwnEklvIutd-re$qYi`jVs>;`?}^(WL?yac zSh-3SUbO##RWPygzjx_tu%y$7QuXabEQa`UbUzqg)_`H0*OAPhb@1u}n=o9`TA`%y zLz_6vEe$LBd0YwgzQ-`WEyvfQFnbn6i$=`&+0kxUwRI;Xnf9SR5qD+&t-qh#qy4AO zWaz5px2}(l6}DhX1vM~D(<<4q2x;$_%NSEBVcce@Hsq+{wq4wu4^ZO_&f}a+rF&WR z17pR#9yrl)3B#2hFv#YW^>HPv$Zv}={!nAz$2zvo%V$y#4xC+yyMczL zm@SDEH(-)*@apz_8k=zN{aCKkDAUe(Ick_1nMH7{Iu@HtE_sQekOJ*QX-PLDW;Zx& 
zJWtd&9Tb##%KEhk$#FD#m4dlsxQ~9+gS48>tKtPohqWBUD9M_-j7*Gt7)KETx;cW& z*jUz&pUSAadSiNDsElz1pEtd;UU!(P%i7o|7;N_=DJ|-{ocF@x?l>ZuWeGTgN1)aX zFD*H;Y3z^-zAkQ~7q8Ff8y&l^=#beri09#i7C6YWbKxIzvY_X~ncG#!JFAP+Zk6}F zt7zmfkH_7X5;P_7#B!=3)B5=Qd~{-@%ofjkouf(Hd*DYP5E=kXUnTH*aypd*aBt5m z!^&ALr7NxKI7rU0ZFfDad?b0Bv$$N3w}lXiA;&_5TwU-h7=0m!QR8K$m7R9SR4J9MHwhgiLUfIM8eYbf z2c{Dyi@J?3qhcgJpDH{rUnH&1zY9-hntNEUqcFAKjYHGYd3pJIRJE0PJr~+ft2&%t z?%eCmU$>(xfG^jspNV+gwr3(PT6YB`g>8D@*VEVzeBL&yM$Xsa5Cy#Vu1WMC&nH*A zxJ~b##D9*Ky+6Hf1wdURhT)(-zAcB9uew@~J9Sgf*)%sVh$P+)#>eM_KVI_Tj`omKCF(>}U+JTYm3n`Q2Y;zH?Pp}#A7 zjdSa%;Fe8!*-nKSXdfRx$_*>dO;Gfh+mH6wTPfRbC&8e{ZQfXcC{KaEKsXk@Bo_2%v^?2$uH8dm+I1 z$f+=u>vPY4s)KHAD05YbOA7?U{nD1Nl+7a9i;)3V!^LB=UNf>tJ4Q`l(p;F*wD>dy zAq++HF)kyf7D_Tm{I3<=biY563KOuIeYv`%!(mpj3GicYk{MyxGk0|%sU}!1hH12V zzY}odHCs(Pq3hT`TN+G5CLJvq7#}v1jB1>5JYJ4jf;&1ANR1xCAbKKp*WMln{oc{Y zWv;khe0xCL+_+m<<3Tx`)D4|95JNX3FwC=ttX6f`>A%4f5m|cz*C9*0KyEiX;G`Ec z1N`&tC13O7q1>ArBzW=he3i{+&8>wfOo^JyO1sHuh|}sJiiqtkoA06zkF2KTcy5(& zF)K{~aWliNbPMXS<-XNpi;JDasMY#`<^^h}UEkESea1s)5AKf}f~Ktt0bk6{M1Sab zC^*Y_%c+S;S<*+1gkjMvc0pdQFnG!sJKM|JPP_fQnL5CC*>W!1a|84ARc9w>YqulO z(5Yv=))%_!ENlMTo`)YfXX`CHy(&Q>1vIk6?q<+Yh>T4Of!TKB>DobitkJR2smAQ} zR_oi5!d)qxA0+e4I2lf8&i7_hoL%p4$rwC+t9f1a>}+}x3-nM4#|F7y?Qup znbt^i`UjKYhDHLq>h6RfJ~DRMVOxuFLW4vXUKTet-!^6j=v8$b+;>91y-iP|>v`-i zpQHz`#Zf$-zwb@lS25~%P3MOx5uObF(zVswT(G`8dwJ@EzfFmaY;6tR=**6B@vb&H zM6_gVqy9DHqfHD1JyqO84a`)iV`lf{&YC!LJX}vjBM)~(#w4Xi#r3_{@fZk~pa{v! zvCmQ+Q!IvP@M2FzZB6(R=XYnr_WenEtCX6qf|gFqJp7dy-{xlf{`v8&DU+@nSvuN% z-tFye4I*S=wJM63AGAMUX=(WqM*^z19%~~slg+8K>w*6nZJD{YwiXyQN0Z}WeA92P zR3w_WvaoQ|8zE=#`dUbMSt`Nb@pXDn>Rs&{iQY$f#Apj@3D))(A81v`Dl{rI8m1BSkDdO-(}Go9f-2!cgR- z8u_wJ88B=l1*d-=F*B$Wu(_wKgCJ_JsAM0^Zj!69Xe{q6<9o+UR=Y&~lQbsn`*Yv9 zdbZPj^A%$kEr%**(Z7EfY<$bi^re4I^M{xvC{qj4nP(}me3q>g=~V`zlTcuQ=qMPi zvD&gbGE}Guq6{Bn1mK_D|*+PC@S;L zYI^p4rAd~ao}M}0Z$;3oI_u{kdT3A<$f<(4>r2YN-tTk_gK0K;Z##=^ZEfg88}Gvu z=!RG)S*|q=b6rHs^Xt#ol|+c>B--}Z1*Lsx13X{F5QDa&nshZgZjetg)0sgtI<8*$ zT9sz2la*SJGjPDC*CwvcPK$l)#|?-Ivr))}Cwg|@=M|{0zu)aomh}Rvf~EUiP3DJ- zb;y~c+}7$&O6OHZv7)Nx%Ia){bw3ICC`ZrhISpx-MlPE@iU}`=*}y`=ep#vt&}G-- zaJD8vY_hnx5QV?me4W)KZ}VKM61vkrLj!uVt?GQ6Nn@KpfreUQZ#K8T49CF*1;Jwx z&`?YLNjVgaA#j8^H#xhw+}+)Iz4!QV+k=PGrK&m)N{{Yl;&FazM7Uo<0W_K?xEqC1{awkVGyOghb}uF7pG{8T->AH=&aS6 zj@O1)A3}rj%9$J3_?z(3F{U<`=bQOY{BN@)7i-Vk@dIO*U6+Kv4bkcwvInTZOzls5 zQdI(ME^AM{{yP&|ob6AeY*0VM=9`;SIyyu~KSbSjgw}3Z=>qrsu}5T41OjS8no(CI z?vmqPEy!Mfmoz-_{GP*M)L%Tm#B=dz4)KIBh0!_~p$eI;kPXG0Ha9ziJjlJvEYGho~Pnt92BG~y{ua^9qiV|_Of z@bOlO79xJV#nT@w7KU$VZTg8=I4v2Ce&%G)iL;z}*1Rb=HtL}Zl*IpdR z8!fq8qQ1pqRgEFv%SND1XWozVP#!Lgz1D~hI7?BNf_jbD5cn1OoQEn`eYiE`bpPIO zxU}^(SF+&JmIcomH-;lFD-qk9O9FKHM+laSq`bVe?3>GHEl~4=yReOoh{%5GbX1ti zW=<3z-p)98CS~Zu(ey6#fx11501mmR$Ai~uj6ycuZH+`k28fzKpfS%3`^nt}^bH;N z>wSrzcaBh3y)mZdr@tgVk1e>_&F^DfQJxf^g+tIUPfl!3f78=5W^o%h+*sIXqD+_e zWRU=u-iR+S-4==TTF+)I_p@ zuLMsk$c5bByVx#0S==5gSzDXg)$jnCJzDeg*_o>T6t?PWvp|rF`j*-rchOy~w3sjb zEh}3~T1(#_Qm|DxETf5g2QK*pRyJn5(jScvr*uMPmWs07@8%b?yB;rI?d(0+;}mdN zXi~qDN3U=LoyL#sxw-CUz1!?vXk^FxJNGJe_}bd|?3ek@TK4B@WV5*)?noUuc5wNoX~NV{Jb6jjrBE}rXU$&G1O>(_6}yga!RVKBcZe2$8UaIgu;)qWmM}=gl6TzzNK)pxHZmw3>-}4}!6W zWw>BX;B6g?us7-4FA?@o4sN@`*5g&RGt?!GQ zq#w&YP7Lo6k(W`3nM)(^a?j4d+ru;)vUXW zgpVCg=!7Z}ZEC)=5|j49;O%24(W{ncvl}4G=B-{&bK?8!C~9k~sJxDJ@&h5yeT1By zF0IdjfE(xp6_3Z6yv{^1PaO{51G!Q%k0_?x7IXtn&p<~Umq6H-Ut`VXV7y(8p>9QN zC5iOg#)8-VM0qxqC2EOb_d=fpD%?CIYgQ)>@K6LRnL2sGog9nX!#xQew5$$b;B`>iJg!OL_^7sUSe zn}4?5b87OO#AX}#`fxsQYYiw_LbC4ZbI_EXCJ~R{YF!DQFRLv6Ue!HrNHzSAbeTI+ zltJCDC&|*bTceDr{i;qZ=V`l@OHZTyC^reWs;A=@Vu>WdCEJ(4Mg4i1ltWx~UY(Nn 
z&sRHLGLd>BZh3JDJu=5mXny_O2d+NU212aWN`-8`?&S*IM`+LuNnQuffjz!>zRnNa z7ILrcBI)7KE)LLmr5wQLVzrQ$Y*kh_ftQVC5jG ziAz`{B}84jnI|dvVIzGKQ26WE4;gNRDx-elWM=my&rlx9+_U}KnIazYPo z!V8x=L;xQNHl5RqFkhAjQ=?Pxi@GJ3NV@lyh@<~cscJEal(^E+LQ$qbNWoDKpCo9= zl4#_y)g;DZ_i7nOl!z6_XjusqosO=;)3DP^zs`JA=ckjpE-SXcB+DH3uW%SKmw`>K z8>Xf+Nl=<;UEZRYQK|qK6NN+XQ$L&1m>|i63tK(Y3 zYhDij{q9QoN=ZVV*3)(nqQ&LXv%rN~v)guUX({j8I(pHWceOVm@d$a9dv`R>bg_QC^n8_2W2j~zDDg~|Z-@7<}LsfA`MH`Z^ltU;Lc4<>be zBYH?cdwwZ;?!5O7_mtmx&hiKb*`qU@zQ!-wp!Z-@ALb78xPr4bOp$K zdW3Z&PhZX)WWlP`^i=Zh!%ON`s*7s;&j!s+M7U*X847X) zOt()22NdoX9vfc_x8IbGhTlj|lIPu@r6ml>`C`em3+-cn{krS(FI-COU_UNjsVE~o znn+1wt^d{89h=ym+J{8g_Qx8;(fe?+o;`=#<4##gNzF`6!d$|HW2-@A0L~R1HuYuI zOR=Lv)S0xnwX};*n`=QIwnqFzs@q*}jwK`6J)_YF>E17qPjI zX=n_bQl%)}6Ri-M!Kl6%{zyZ0vh)>Bl3F*f);>Ov7JM;x!C3;Ii$7^em6(Nsxgj#W z#UG*RQNfs7H%dHm?3rR@B}nd|8+G6OdnJkdF$5t^Y^oS(N~$a7pDTOgx<3-?;uwp< zxSUyAq#Dn-nW>8mKlR(G+MbSNCl?;EL7486^nlpj6iB6G#5!4GE)L2d2YoZ(-zWLM zg4xtqt)PtAOUPP$X=Z%l6mx)p*q7tO{9?;B#M@Dg!c3CD=2A1m;|o8)=vxZ4ljIm?G~xVRZi2;oGAA56$eT(Drq!6?49d!p$(p zlo%}B?{x>PLHY7+Cmxyh5+#eq(PW~g4_}3H@p;uMtbZ_dg1-{`&8*Hlw>~TDf_l!I zyg@zSfqwU`e|{4go4u85PuB@CBz%FVK{RsdH5XBbTwS+?YtbK@o14GJv~EhH z)LIM>&>}6KOy)E_J%SEnHdb5VW3-k#AsrQ~plbWuRzBO;9kj*Gs0WgsBMd@$C7i4IQjwZvGsWFw|)6 ziP+TMFEZ5NjE1mr(e0aadfMj+GHgo_#q`L<&3t-`h}iV>@YvT9F*k=NChIW2eSu=7 z<6aas(TFJ)xci3t-N8t`$Az(1bdaX%jMvq_H=eo!CY^Sx-9l(5Y%_u;fCS)FT~t(N zzqS`QonW^|KXQ~yEbt9ihDpN-*2a3oe6_L5_3WwUyCW(6NCJ$37+%U%3iVWojMuI# zY+{Pqc61Ij2KKeVd5TG=*2*{75;YihLQT(oc2b8=HcP;LIo%S6Q9T{cOBVsM%OojM zXG~<`@`3a!EmbV6-@tSSUW(H?>wAGXIwztXIS@U@{!m_I4E!wlo z*VQoXX>w2o2)^iQY|Qk-ddE%vrLrKii5dp_?k>UnH0ZTC$F^tHBKt5^2!>_{jT{Tgo{h?veN>rLBZ0^W{K zkv!MzkwF{$&rNIYi*v%R!+mfj{slBR2Z&FH(W5wa0f~XM-bj4p#^Li1iH$o( zSqI$xg)mzq4pF{2$#tO`emr!L0D-Th=!f-^@i(m13R$&+!$QoZe5E7E1tj&90nf|& z?dEpL`0>OV;x9=P&{d2NcM`27O-2(OnfKVQ2vI)x*oB)8WtvHg}7y%uSpD>02&_$%Pc?nwi~r@~bD) zlAyNhdomX_lQ-4B3*$L{JLBP&JKW{kTkHzVs=yShcxw{FIA>xf?`k$Y&f0^@&&8#u zP!FACoILIor;BByzubI@Jc+h&uy4s{{zvIwl#Di*fnb(Kb{rR!ueUso6G2lVZbnHM zKRXLguX{iTlkVpJ$%K&%I{-?{iXL4zMAAd%nrs>Quyk1|k_V78<6O*R&+Kh$Vh)p`F?W?NQX4xjs| z&1CaC^pG{+ndI#_)$X3b+TzWEAx&vXC(!N=_>;YYS~PW!*kr- z@ZBqVpKBQWHE+z0i<6zPOo=JxwF*qL-#B}_<0P5i6Ie^35Ja=%Wb_hn*ZR%j6y%NkKV@{dSz8HHpJm-?(S*slgH(0w^!;=Z}L6Q zs`eXg6P0?-wW>sxn!fn1-|^GMu1j^<{GQu)=4|LhYe2^p?K7LAhgKi(KZxZvWqb=VR!;6^>4V2PHdTsV|| zv*-D6P?MKi!**(jr>D{6HaC}^do&$* z+^G%s+ME5uGyU`H@*f)2x*WeetTFH$?X*z+PY(r>g_zOL~|#qaut@rO8y<1lH1dUi}~ZM}YQE#4%4vicNFJRb_AD%fxK zB|snH0qux8I}^EVUd`E_IZS4>u}c*D?95mJQVCNycVAA=SDVh3u6rW!Z-de6-pMoA z4Nh<`AOyr22{|btL%fTNoQagg6qOQ6t4~f*vpidUsVPZGd#F&%e}s{UWk&mX2q_;* zp;S~<{GgvFc|$=CPiE43ekgRd?YLL8l#-%*{E0#W9E`*dIB$fWYdidxwlTNKXk`2h zvUU&TW=yWD<*ng>t?2YrEQ*lN&jC{`M@cP?1<>;bU(g$!}H6WA+m1nEs)aIe3X)7pqIdk#`C>!6ph8QM~J;bF%x4s2g>b z%4-#6Ms-fmY1PM5YJcowI&0Q%d=7W#Tnj`%kGq?)%8qjsz2eH!4`dfYUN3iPV1IBh z3p8TN;iCD`QBB*Eh?{R^CHrkzj^u5__3(!#ctG!+7shS}2QQB>FGMSgXBNC2vWos~ z3244j4{isqw?TbbZJYwomB!~Z*e>2~AVxzmLmhz`vfui`!kGZ~ds0&ZY-Vr01@w%> zBD9B-9bG14edti!S9O`(cB$dD?QpdfD7lt~i@@ve({h6Vr!gm(1%57CY5e6DY2L+E)i4foD$hvQl79Q$1v$(8D$Mcgibok?+H!ysq^(l%qC z0^0;`galg^Nu|dg4SqRxar6g9UUe|(cqDSKb_=86JnlT|IPc8N-NoM+x%#Y8!nF48 zUvJd;k6&)8vt9;!jkxuhh3IN4^6sCO3$IPk@U6s3H|#44ACj%b9=x2+5_bG^{TSu( zSbI6x@)m#M%GPIRD^JH>xa3>%%0~$?m#?!!3Wv&Cfxjmdu`BR!dAcj*H9-si=GX1F z`Y$Rs>Y8Cjdb)GUSQ)eU1BBAov%GnjQ?pEHt}|d%DdLFnwtrZipPSu^;Y3yRI7?Sb zY-`5F?U!~Nexezv#W_8|9wug`^ihB~0bqZ^2DeIdQMGAByJQOdv5>H4&3|>USqhDV zA8=-aAeO!4=cjMcP$`Fk?h`_+PY+>3k9i#p)z$8(5=>PkCpH$cWK8nIE&gx^Ma|_; z)fd!{tbFq~ND2=~i?Nnu=a2tqqE%OG5bd2}Ppbeb_$69FRqjTws3@^IN(vibk|+N( 
zwi47IFTNdn`ca|Z4T+XeMpwZqADZqf-8knaFBOg?DGpyyKsKfcm-L)>=1X6Dr5i5O z1dekrC`gLH(RJH(c75ON0^iFKKNCOnUVTe~sxCC%A6lVw<@>_o`yX=m-GS)IMkuwi zOTrJ%f<|5HSc!=mKk-@IEsZgDc&s8`@}(YRi8xUOd7BLi7I z1VeIeU^B6Wj^~?N6viGRx`eC5`}mZM2pvv<#dG25C7{{IPH_M0|MLR4G?D|a_Dd+? zOzIV?l0sa%@x3B&7;)ta5ej&zm5Ff_Ypr(ZErcLTHAXaNE0FEUtjnqFhR~VoW;4J? z@T{EcR90PeDcyc0QP`d%x4~`gyxRM07wGgEH|IAG+f~=Jw2Y`wA09x><2D%$N`l=U zY?Z7j*$I-*B`2x1KMzXrP~6&;O$5Iv^dCQar^J@O1W{gu;VTNx{NACQg%aj(Z|I%> zy?nB4h@0Uvb9yPVaiqzKu&k3TUhHv|eSS$Pcjnlscj&0IpsEh)6D3IolpFpq7|@D~ zKWRV5_p1jYW;lM?x&DrqDBp@0^jyQOgCD+aJR2793aqH8((Ck?uieMh>O4Xtnzd>Ffyjx#)r5HP`l>2_a7;nzc>nN&5TUHiLWi~i1 zi@+r0ycJM;9|1w&qrv;rN3V9v{S|tk>$@6RDidZ7uQ#;|q1-oMr+MECm!7BdPNKQf z*7WcqSEO#c6+YumGEv<4C{S3YPE1I7IDbIl=cHcnBL$97{5Zwj91}mCPB5Tx+~V@$ zupY=~qO@Y72 zUQxxqr2I&c93o6xXC;E*W$SeZGLuu2yV*PP>)z5>^KKB>j=2vk{mvZ!7}N z@iK8od(#@G+;2wU|DHI8akiyOf}O#T--omeO}S_;AM=4 z7KAnSN3XGFB_uas)jCYxDL-4wr{7V8fF?0(XGbN1y?T;KgYnp$OtL~?i3 zXV~VK(hL%}BN7t#J49#`PXnko^3|Y?FXvk}<hLxAc_U5uTJD>C^K8cTN}KCq zf*epWtNml!pFc1~S+d!w4iD(@cBZG{5sU*OFEM}zNGOG2ZMDc>RK0#vhu(x0zU#>^ zx}$&{LBKXRjZ6j&l9)2RQC&peYwwVHgC!I&Iqi)_z{jS?ORUq z_@xNuw|N!Bc;pCbZvi8KkG2td^;zW2pFF0Qq+r6+QF2dHAI*5<$whQVQ*S|##=EEwU_-11^K zZD(rVHPs@C81(6Kg`esnj2TZqH?G0Hl>k@Q*cPf6X|VPCe$8FjN@+!=}~!H<5_E? zd=cSsW6FlQcbx|E`?kbJ8Yo0M1ue;m+xYuAR(L?ilo24*Qk|!FKJQzm%4*D#U3Tc; zQQ2=NB_u|ajPuDI5#&e@{Y)Ao;}<<{8rjrH{vjfgj+mkRE%0D}4L}}uDnPg8CY+0_y>(v) z&;(<<kX(<`^#h|%c12X$(gaY-y{s8uvwBLBSi^WAc8QMx5$;+UpOSt?axR6bJbX<^7e+kI()d z?-iCU7XI7^6NN61M@m3R8nVlJLRi^WS5e`;p21}Ox}~EnxwJStJLBQ9mO9vW0iJ=E zN*257ryp5;$096m4%j=GuqCGt7m}$*N)@J3EG@Mh;*2~$B~(>Y1DdQ@d6-Ih>-0Sh zj(8sitEeqPKTRqHls7N$zf81kQoo?eA+tH_`j5!EuwrM18c@b$aSP(+8-27>dBX)h zp1K8kIN6~_Dt;tlE4+pxgq`ODl5?={@8`{1}23h3bqF8 zsEc6*m0gK12w6a}c)jZ&FYYAemv6D?Oag;|uR)Aqz~*>7o=>8A5-b9nf`K6{vQH;I z_lJjbanz>F#h;a5QX`i4X=K@u(FLERM8%KmnYkx1Ng9L_?5rJRbciM``uArVI2!73 z8H;<5@200$ruu?&uAjn;NIJXFxY9BKMhErBY}ezp8jgqFwT_2vmsSt9V*ZVFd37BX zXkAT`Qg%$m(zh|(e@Wii5-q)Kt18O+;J22Ttp(sd)Rw-}G<42XxIFhoWB{ z!TsFa8ir@VZ}ELP{}OpruD4uP$q^wNHI+7~YuRoZwV3-6*9P_830)OYqsSv6F~C+9 ztyBso!fzJY^ygqOblHjoO}`R@(1Mif$W!VF7y+-#)rIeEPB-b_`ZxR%RG*+1bLs>x z>?8g+LxuakoCDsZO?YKWr6?6aVSK3rLNZ1Y^2XM$M)Bj=EEm9=?^!?rCpRB!nLdAa zsPN}x0shvm#9`F(uDC|;dB%s=Yi|px6|`^adVkP+FK;jRk5P^#-H(BK6n+ELqnl~$ zfH;gqqgqZInSN-s-dLS3Ut~|V$3Y1xj$x=?VkLBYx`84nh^kqdupY{dEzivbYB_Hs zM5w5!G_w;7y1neya>`gNvZU85PUxkEx^KbZ3N&ZFckJ__r{4T#eUn zOg?ZqsTP(e_*IdkqbIAj9j;K>w)FWW_ksZCkPo!IftFXu*{b<1BY(DV;$;QW)d3Y; zH4jlyZnxfyBCl=R?<9`QJ|9Hhnliog0l_E$h72kFQ=E(>A29HpKrTNGV9x8t-H$NI z4rP!s1|zM@%geL)9T}hPPe&}PT1~umg}DzYvp(s#9CKg|%B6EQkRYH-YTFM)&6`i4neps+s{$ltl%b3rQLF6w;S|(YJo2#K0K`kku9V&(K$EwdauwODLSlZ4oi)f z-SOU?`LsBrMm;#@eMnDVojkc8CBeUJxU9^=ZfNRw_F#BobRW0!W%m8PA`FX~s%Bb_ zp#}IxviS{YWYeQxVoS(uJ)=ozDht?LYMFpQ^7Z+O4j z#6>{ZZN~=w_W#!o1w@65%Pr{xFpRVZ^xm|a4LNn z5LmBGDr>Ux@`TVckt{dajz3*v(a5EBxs8lJ-(EM$lI*}ltnYyLc4L}sHrN~O!D}KB zcIazJ3DgtWUHXt#E`g<+$z#jPvf^^Z`Qs>9t-$8Vmxh#VF-8}W*GqzuKNr*`{(v*n z#)kSok`g1`_y(CFLteX`?QYa^o-`LL(~=-eUnSQ6HvV?O)B#X~xhDN-Om~n8CW*<@ zP=Z9cC_+p20DhP~6TqFIss{(2W zt|G!|B4RFuy8M+2pE{oZBl9dAzba&AB4(PZ6t|l;uPeIqMN75xZQJF!v-Ecd1-|hQ-OQ-Hy3*O;Qi9#>qQhiC?72*`d z`SC?>0cvv)@A>}xVw8sGGM;IGAH8ffQ7~{XlizatY|ke=DM`h^QqezXeKX4+oV%yP zgvhI|TZtKBw1|%G){~>FEtDN=JUYncb5sM&9jy+(A+Tzt0&Ep6{v@;5wfKGO|N9K( zcERqZ!;U+K4)O}G-f*&Y{mlcpJY5?f+B-of-ypLN8*l-r_2 zT5EQg*oO}i5uYQ#>>HxYB06sqaC^@?UuDv4@Obkd=-zEHXFC`#AynzI9Qq7ceevI? 
zZ|04Rf&R!As(gqDm*QPrwcp{bth2GS?SiVRO_}E2{iy4-IgOpRy*t}brYiCA*J*{i zNCj~Jwzm_vw|d5Bq++0=ulD?4vj6EgNtE-x1BOQl;i(+nWn=@{Mfus|!~>|n5!zHY zbOj8jWZjn5& z;%4~~nNyH7W*1Fz$O{UlqSTI2(rS3?LJ?X43&AnTaVscUr-;r6e-e>*W zDiuz{fqWEwgJbF!Dv9gDYPjys3ZsKLGbY2*2l6 z@i3ST$?3CkIWNu%^dS;4+zT={C;IA|{>v|k`^&`EVd3b2f%83c2JRFo2V?+avlPE6 zKqk4w`?Dki*)ATJpBVrk!BRHavAIkVXm~p13JDEipY;5POvphO*crf3D>FjBND>f& z372{vDM!G$9m&%PBeMx`CA$1P*?xdjZ?%Vdt4kJ@JCeAT`O52VSg>gE{(YbB`(*E~ z54Ynv)X)C>i4!MI6qS@-y!uj^C~e+5JAClKw)fs^Z*A@E?cKd=$ESNg0qoe}!*~$Q z?wudK^Y;2B7cQ6Q8y?!fA5okdiS0okCrq0Bh5NsD)r~hbG|*GPWD%cDy#NBgqVQJV zd#4y#7UEvUVKR|Z9XR6sb$96C=!kpAOn1pi zLf9%@HLWAbPNp(>7?TzuXAPp^(MDCc7rH3zO+?=kqe#~K6`DS&IT!;OaSL+*3Crw7N;uJQEfvZ&@#%V)$5kkn(3R-&6JhKMjIA4eyc?MR z0C+YKnrvR~Jk4x`kcWUh7*8mqIxt|MC5wOkY%(fsDim+zX|Hl}GQ($-Nk9EBp{OF(XlRv0`b~(k*B5 zgfmzxN=dO)A=O=Lk8)#6!~$sUz}DK{zW>kx7ak{Am6w**)zmh(wEX|gcH#o^88O~tFGxm7Q%#1RVY>SyC zOaI%q?marUs=IFAD>?b|W$t53r_OoLdrm?3?JiY!m6ex;g@&djrvQwL*qZwK((+Qs z0e=438JTpU3ucTWtkNp4!aIVhG2}vmgELaogja22!`+e+)`8B*Ai=BDgZ!PFD`G`e zMNQ4k;#;*fHP{3rG$brKCMqpCH99<^wY~k$-8(I9txiwn7j`;MOHBnKEFJec%B!l8 zhJSy*xVZSJaO3^y{L|XrR($(*TSq(Qaqz;s*hoA!>-K4EX+|1QVNF?5Vq%zMnL=B8 zYw_(n@PZzd#bcP^;SrG$VL{4-5n7|QwfWl3n-!H+c-R)(OGZRQreF`vgam)T?z=gw z<4*9qCFFg9My%DKB1%4txhw!;?5TQkhhRSz_v(qP&qpwOthkZ@?-0mhP!mcjlHODx zDwT)uUJ;U4FAAF6L=>up6qGWDfRBK2S;F6^q8d@u$lGu#k9`Kyh*s@pN^2)NFs7Ws z*Ob-^r8vgn;Qb-vjAG?zr5A4EykX~9wyKPt-^%M_eNS4T`wL94pbt2m=uV+D?OIbM z7&LupE)Vr%5M@o)U$_a~1WI<#lk zq-oPlE6&dZSOR(c;6c>=)Ul%#C3i)7%}vcEC3hR@>MooXWx1-Cn1LaBnWS5I4eJElGZj3tt8 ze^|pa>5o10EY|arS4e1RU$aeY0iYHWN!rYin0r?cW4ni(rakMCQ;c(uFgpN`$BUf| zJ(AeTgfONN`Md;|$;(GUy&YjhntGWg$y^PQpAmT+(sQNct1sGgl=7E~`VatffSvks zJjqf}!6V2`IZc&dlX!?9f|Uzl0MNJ-%tCN$2MbKEs?oddjqe@1&B5UvDI3i;`3Z)) zknC+xf%nuJqJM$aj4R(F``{&V@WGM0cpTCkOftT(`CzT9Z>+4Y5-L1ahqw5LhlccX ze^Wg$RW(&zoy1`Xu`#h>VZn~K?7QjR3dXlD($swk_u!y2LX}ojHZ?Z!pr~DXYOS?g;GnpesFb86e}6w)S1p`WkwxQc`cYFaI~~WFIiB=fsbr^9*mRWg zmx{Uou0xD4y6z{CnSie3J%!o2W2a93*H3?7V7bBN8ypZkXWFzUA6~ojz@Fc~`pW3x zBfk0NufmAgz@7W|{QBkJLynJ*{rZ=_k~_fpu%V!-s;m9}&wqUO(gpv1e!~h1{?Avx zCcL)q-Szy-e~_nl`t(VTOZetL{!^CY(bt-WhPU7U)5h&vs;aA-TbesMJ2AtErI&t2?-3DK55eDo_dx8TH0Fn9z5{FU;iA+#&je#J?}i; z*<7}niVrTF*=i(-R6UY=yYfzhS54|98ZPF9lE48$L22n3;SmwnFJC^mcXw%NX-hLU znF-}}G-P&zhrhVu4UD&M-AK>O%pX3SBV!X2FBhG^eEH(J6US21j6VyDPf|+i?D-3` z2jyZH?3&7|Qe?W)g$BeQ{EC zNRYg(^nk&{!D{BhN_q|^S*k4QNnn|i&aFp;mgkA~lyd9#bV&juI!gizr3>i!H-#nv zKAeQhp6MJ2A6z&y+<7u&5;gTq?$)v7-KXqMMSXoE)`WMfgao`}SS6<>goK4O)HGC9 z)?$S?jEEJCSH5HO^@z~^*e4n)wafeVt*x%BDzCwxaH3+OqT?d4WX2$7eSf#O3^$fo zj~*En8519gH}e`iEiiY~e*tp2>=x{uvWcXqwQCH`k8>JN`<<&LyII)I=2FE4F zB&Q|@hXjfJ6l*JME30ZlLZFW}2=wEN>~4QOyPpO_j4q4@e;6Ih=3TKceHPEczEksVhx8 zoxCK$E3o8CMaqOU5PHZgRf;o9GKLd1_ST(yUwr*dyd4D?Je7bsyNHOeg|lbQnm#E! 
zG*paF1M}LO7-r-J@E1Sx^o*$!eYp1#og%W;R~cWy{X7Nu`LAC*cJdUjAeeviv+pIu z#@@Tv`I{GC+qiXSU3~*~cMa(89~v6+0lxp*$eqf+9-Yvv0wc%;8wE9uNLE##ULXM&@$C!dme!WK zx;o+2+SYpV+}Qv>zpRYRsE9}?+k@I#+S{3UIPt0;&xpwvCx8JC&t_}zYxYy-8QMcc_cm38)cm)Lo4I7d#au8mjL7^C3 z$Pn)@+O%Wa_kRA<+a-5K3@yyc%mkpew(fSx-O9=;tdzvU%!Y=>x>_UW_Rf1{Wu$h&a^ycg36%|MWixXEbT_%<{ilF#e-`Kco@7~gia-4W; z>Z%`Hx-ucoI8Tb)s;X<>-MD_st{u<}%*y`6s)r!pQ{T|AbKjo2h6XegPQ#cpt*ETL zQ*t+XPr~^N7r*hPf5d7~>~4JH*3Iw!_=h|9?+ptH9#uFpG$g30xwX8y;_QVBS8rZ- zFgbrv*FeCBw~Q0PehG{fOdWg7)foJIC~}h|t31M2f{CKUNl0m|N-$F@QhEg@FG-+G zxZaAnt*Nd0Zaqi#T&BAo5+qJjy{S@i^jHK5FyfX?PpG$g3?u+NoT-Zd2$OUpsp9o) zlrRYrgOwgI)dUUnYD!jmL87_{JoKtbC{;pecw7}A9QjXj5nwWL9#Tp|_1%W%hBizI zKu<>j2~hF_mdg^4RnZ3*P9=_3l>`||xp6O&c|}O65xm7~sQxxh>|Vv22C&(kdP;rg zS&bn*r*sHe9gV79)gTa*Dxsv&7(OEj!c;~Cnz}BX?2ibEo5GP{-V{ZZr-@IdHwqtS=p=Cej*^y z`IHLh5wgf%)R$`@CnY7NWMrK>e(cKS3w3q1cnpNQ-@JbP%B7285fLeAX^}1yUCI!Z zX-i6tW{Oc;n@7K|veY|LQ4QTIVudKu&WHNin)5#a%W{$eHtN`a}YsKyMwpu#YRG9Vx@DmFSG$hln-G|DmL7FpEG zP30#)rnNYZE^Zxa3Z>g)P2Mnvz`eTWBfPZ{S{G@54 zC(oG>78yp4RJ#w}+_e3}y-khHBgPK>0!A=F z$Y^5=vaKKP+4%0(F_T6<^vPur4vpqedFQnccWvI++SEF7{II9Lur?~jIOpHGcISuR z`ejwQ@y93A&+xHBpZoITSTFyBZ~dyWq=I8a8!mrj{*njh28X&{J4J_(tIdTAuPijC z<3~Rpc<5f2rSznt8qPF-6=p|+u_RjZ+KrnVx9nOxe|AP%ax-Qhs;b`Ku=VKiQ)h~b z{^<*U7vOJP!d<<7bK}-s3+GNxOU8RzoWn3QJOrnd?zwF{Q5nWYUF54H=G`qGIC6Z; zj@_Sq=1Edv2>AK+M`pVZ9Q@u-pO1uN1TwH8xOB-f_nwp)vcQ20d@?Bd zNMz1*ZlcDjF@b`ZX&8{kEr3~pn8{Z{S?K{IUm<9a5~o$sK7({VBtoI5A_;hedHB%6 z|Ni=a2u(>v+24KjOZYp=bB{kYqF@-#znD)?PKYPGB;ZmuBq#)PbL&3(@To^1iwFr5 zfQIJgZM%2X)Ye8tL^QRuNL~Sc0hnj}&cA$v>jQ%V6CHc<*4Ed(yW#!Pin29}mp%39 z<1taj8M3LRwYIjlt*tdQBqS&(@bk}ny0O6+b$A0iLcjFZ>m%}qiDjEm289I2L`S2_ zbxn@2=} z`@=1pP)+_|HMz&d#{K;hPYuf-hF2n0*Vi08di3|Ny|i)LmV&&z&wuJUWP9d9(dOOT zQLq2_%D?30=35 zI3BX;sOA{(m|`R<({T4mDWxYBMbOsP_N!;gD(@Q8-0qhfJxxlmfyMp1)C*eQ@}ao; z(17`z7_Wmz9;;edIa=uvR8WcNYVuZrC|ea^8p{-B6;x!UCL;xAL2@Cz zMj+aT+P3;C+@?hfusBXNj*`jrxRx*;i~93`dTiQBQu4COQ`Orc_gJI9S%5qpDYYgq zN>D0zk2e}aYvNCm`uT<+MI%%$@6FA49yJ;HlFtf=8flH zId$j+WdiryVR2X?B?uoC;=0PI{Wmi{gQ(IjP?s;1_^$$Ty^c@byWYr&Lg-dT(ELhScLQBFNS@6 zUE>FDZiW@FxVUio(%i+Z-$@!DH2eNb0bUobj9J2w-H=IZMjZd@%cI&-I}79;;h+RUE$Wz4qbpgQq$=?p-)@dGVS#n7+j0y_jc*skXLW4A-KQ7g}1{(lV1{ zN4K4RFkZ2Q^YSq#O6UU9hFsN^6I=$$^H~-J~ zp8wrz(g96Atw z`E#G*U*F-5%cQN__g%kn^R*X#HtvDpVPRodcU)InkL90W;M81p#t^@Zzet&!3DsJQ z>bUt@7YMBK2qzv?#gID*6iNld#MuaGB@hD%9I2-=brJ5y^%d_nhgMM-T1hGQRKI=+ zak2OZBzo=U?P6mK$Ih+YMTrB1pY2XPeq_OV94N;L1So4Twl5f2C&brMX@##pyV=3sw zG2^!G-mz}u`epMLipj^i#>V}J4`F)rf#D->-7ZE%!i(skfq{hug%d|TfFkBD^K%B5 zl$E}@?%iGc_kH10&o#HUUb}e%?};BaIB&(mMf6%s;VUu}vhuvQ=ncgoy;sW~$@> zl5HXZQ$K2M5=;Y7O~fyekRF1mm&Yn_LQ@7X6UGX_(@z^I^3)dB~!4PhDIX=(ja?o60y{Xr0;o%zi|e6I#PURXim62Y%la&4;g%Ox0=1O z(t8`_%*m6l{qDDTsUn_s#yhC{`}tQ_R$e@R{=(VQgNF>6FljP>pQB1%R#sMY>JeNi=}4T&aK>>g;f#+@n73r z_P{|BkH;de9mh_ZggtXMeX#y_zxdz7dv+zJrs6>uJfnT%>Xi%U&SGFpnK`Rq_z0q+ zYaP<~AgI^Es)p0*tsME+B8%hhuO(X_@}#e{mZqlD$4*_pbZx}=QFwwCvmTevUOaX9 z#Ib!xKKbQm#!j2q-@jioFh@>Zzi@5XxDnB@#>?Pw+DJ%A@DDU5D{X%CC$Szjcx7q% zvAu_I=bTfJ7Z~KsDmpgCSQ^>Y2Xm3H{OtE>S?T$shJ{ClH8-`E7T?B0VKy_@!$iH* zFGanVS621=RWz5rNxcdof&Wt93&z0e+jqgLS!1S*xMw`-+I8a4>EHbL6|B{L?jIh{ z$aXF{$<52|?^uS0v`398c;fFKjEalKYjjI*m#zEb#{D~vyzrA(M@|?%h^8YQ)jM(M z%+0ICnFG`P0{pLEzIpV}sR2W>A|hQ=maGYK@&^=Ozw^qk-yS)xFgrILtA?4eTP(R< zzWM##cW#z`;Tz9QoijcvIvfuqwzah45%TCL?0wMvG7xd++3jf|9o>KG*0tLMbF=#T z^}BlU#*uxeGP2X~^ml4T^5?(yjAMS&IKhAI@Bi~s(X}OO=1rb6KG3o5J~}ocG(5Dm z)%misXTJ1U)}XY$eVvb!!zDE{32R3!!rE1*u_RONx)pNKd2JO8#Z8v@BW{Pn*(L03gPZ*43w?sv4)1 z5e0b}sm6XL_}~d{Apk&i7aFUdqSNi4n)FCmdpL@SzlS_l^|jKAxClxKSRsAZi%#L= 
zN}OkG_LRrCxJXV-Od3A4;78B@^34y{&6+$7C+Ft2)?=qnUb}U3bm8!^BS)#zk$PHO zE4JV3M25~6WiYfZZG0^-Bt*QovFP%pQ|Hg-WDi)eV9}|fv%h`yrTvEw-YF|fNl0+V z_K64&pD}3)(mr|q+}TS-Bk~FWsH`Y|@1qTIF|o@QEc(qWFG?;_5mTbr@aKWyBVKv? zjk3yeJSUEx#U_)0VLKo{*C9_{Qnm(+Y)pXC{*lG}?22N{Ta7>~y}EUx8(Sg@M)`a@JBl*p*UuwF5w3faFW$yDjqwM_$6j5y~l_E2w#cNjwAUBuoZ)y zN?G~xG;G$9VDQz@imV8%A~hzSm5>QcCY2B-5tJ~hz5ASI=~$p0X3iu4E&UIfF%4ar0ZO-Fm%^($AdUb=uQD$L*Y_48|OYVr#R7%}>R zWe+_(a@;t8$wJdJ z(f+u|H`X~q9yEB!l2xm3UB9+t%jS&i>?fZ64E~<%h@l4T4$TvToH=0NW6ymCZ>!w7 zY15|l@B8^1pXR0(Y??i^aQM7s%V*4;PY-cp9SyFH+_-EMs2yxzoe)zOsg*pkIhxA@ z%KlJPiR|!`ml(<11+qRm2TMFu6oSMf6(ul=9s#NgG^Y7*BxnKR8j`bevX^~gW%l5K zSPD^kyL8h#8@8_7{Qj%&C8ecg4;dg}S%b4zJicncU}HHX4tu%U7cQSMxG=}qrNmVc^?MN`31>(rfeU6# z0^#E*;^18W_GYX{Te)^&Xn5EQKY4lMJ6mTjo`&hj@QBb+y-q!GcP zj#q`a9~XtA^YRLFXaRR}gYO{q+k5l&J>YHRR3tijgR6&VpuqfzjKoO8Jb2L+b}o3)nQ zcnC68*J0+`d1zo_C?d?5I?oxHg%IES-(Rj>vwUdY;J9e#GhZ-CBvR%?D+G@%>=D3u z?11gG40r&OQdNrrrU}YTR&fdAkxfki8sqxPy~Rx%TtM(s)QXD4>CQoFr=YXC+6gZY zojYyTFaGet{-cMn7%)3CtG>SBy^R|}f`XUKSr8m#`3@pEm}3ZmH83D>V0QNYLx(T} zh%LE?7Zj?yU3dFB&ZtJ7O0HPdxOMkVJa|5F%-F2-46NkE+bWCi-ofM84=q{V&ESym z(0W!z23B{j-}2FhEn6@hiM@KxUA}nc{JD|C3MY;pE2bk|l@=W!aX6lA$KMwE_Vw@I zAAcST3=F_rZCgjj=?iDEd(YhIGcm=Fb(g#^1O+ro`rYI(5rANvguW}fecUA)Z0_p& zRq`R6sf*RIW0kili3GOd^b6=0ofHz?$Mvof$<6zsIFm>zSw+{-3myR=Hi<_nayD`& zUvBN8tch#IY~>>IydO!LWY>quBRpY9Ad{&;d74^HmHPbQ;t^m<`>|4L0eN+#WHv4GF=q;ZQgpf6NkVVsWvb_c zOeGY_m^%8b)Xa0?s;&DT&-m7Nexc`&81YZv`p@#xlBNbc^w*5t20}wZW8z}dv$M0Z z2V&t0PX-Te&z`>kZ(z*K9*D_4nDBV@v!DMW~9FS@}6Mf{&)k`;$)8hRDB91u{>BNDjUl?RbYYF((n-~!>VD;F@r^sj z51#4l>Y6flVp3}S>o2@}>;RT6-;7O&#)KCq=bTQNg?0oog4zE#j{-UvtUaZk^`;fCDhzj4MpN4<32*?e%BRUrbI)f_}!d z3DYJ{!1Rj??EH-~>Ryo{~pSL1iZrH_^DH0{KOp4^EP$ zWo7SeTo1N|Gv{DF9q(mYID7Vwe)X$$A8lB+jj4GX6@sc zzTCWH8(vK_cgC!!D7V`I$&rfKVCUS$BJ7!iT_TOWwgZB&Zglq4={xrA`N6M#zJBYb zwA8e*g(F4`D;zpF4-aOMqlcz8>kC)&yExoW??Q`lWq<6Ei;L0LZ~I394RQ)ZzP2JU zkjF|ftSpK)ftKI z{El2t9*4^5mNU4YbM$lB}GrwURC(r5%bX@JnQUP7-SS52Hp&C^aYD(AM*Gps80gJh?`G#qFc=00M4#{5-nUtK8oNP?Hh!4N-6;b^B1BMg~!#@GQ z;z#^*!MlIQ#KcXUGOfGpkZ?7VFl&xgO;vsW{(dPbY50eg7j3OASOMgZ%cy_=wJxYg z-mpUa;}J#jym4SvqmYZ0Nn#*83043rJvr(HUQNi-c+>lm)s7brM2_|&3M4t2MwBoF zCV_GwflOR5CFg*i$hd#NV#(ZsJiIKjs-m)w>!IfMcD%r@zP1Wm6B_q3csV^Di>4iK ze2(Z{UCge_DKuUq*icv3&vh#j5*F<5?-v>x92FOX*X@1u)(5D|#{Ab83c- z)S(o^W6xRvtn|dJ7q~lQtMv3FcD$&=IhA9+pnFhaE13R$`_tP(1u{G=BsM16F-i@j zv9QW<>Qo7Zp7TpGRh9$^Ac|Js-6}XXx_iurM;Hr)k=(%S%x`}6i;L#YIDhfV>9a+@ zd-ct2yZ1i#$+eF^yb1u+foD!Rceg^fqz3d_W2b8bslH9T4(#rrNdqQ#tcQ}Cr(ocE zXoOPY60ObKUjk@~B1@;Q{hS^1=(|9ZMy92t56&3~ z&Si5JNGpipV&w1M(oHQ*M^Bx&Q(7{9^aEo?8}D|E2n(AscH$3y^^4O*=Pz8nRFE?S zO8&vM%CNyhMimanTPAlM*q4!-`oZRnm{eUnYi@^YXK<-`+wR>fEic6jH?bpmS!LCZ zy}LJT-5ebmHF>PDTMi!I9+Z{+wJ-eR$ihM__B?p}==Qz4ckJ1P&A6wIoAB70htrac z_K`SwN_yFjBSsEJ3m-pvPCHgH{3K3I6f2X&poX>rSm{*;Qw?{AY?Yp=FDdPK@j&Eg zKcYaAqiIA5Ltqjp2NKA{1ygbk=!qimuMxuTGY3RiuX#RNyFRvs(A_q zu2(f8ThZDaX}=bn(|*oY5rrPnDH#$>QJBB(Utr)RloFd2!^e7k3wWN6 z6wM755;TlbuTUSWdOK`Zfh0CdnAub#flco%$?AguLgSH&pkVy-!JWHMZenVWo49QM zW&mEq$rjmk$KaKD3)d{Eyi>9Dy^oIXKNgpi zkT-nDlm#<$hvo(Z2kgiBS*g)6;qG>L)(^ zP;7h@CaGsHnzDE6fsfX0U;X&v==c~!kUsGV(GNYb^vtm$5Kf&ncKEmgKR@}~IW043 z;fk3rKmX33Uj6XgiK2w0_=1sn1*7vuju{#j71sU6Nj|WoFs@>ITig0KKdP&3oVR3J zR7^NBoV#>7Ho)2N_LkL8EDev2P!e=`n*ILScYhy~7~TEij;^lH{_|7uNpWg5jmD=B z#(_IR^uDAd`2Vj2B&iv18k43L`ZmX)6m17sBle$eY8sp~fVb6ytCx<)RyO_m2^jt! 
zsu{-mSehi6iFGB@sv?s7%IBV;RPl!;fq`at!H}Uta!M-7@7ygredgRZ{`33)_b0!8 zWYsbZNY5D>N2+I(nvc+k!3Mb}Bv#~J>5%{>14k}%9+9V&l_O1MQ%7*(!>`uJ7cGkp z_O^X#Oh6!+I3+=fD`Mg4pHahx56KyXX}*uPZ++{7_wZiaHA|LZdR=lCMeM9StRVlh zpL~|A_=9?8y7RTa@WM*~)+}8%cEm``^j*1e{S3B++`sqT4eLuvN>dV($Bh^z95vb2 z*3@m-wizen(BP1xCytAj!FDv*EvlljV$=3*>U1PpH8CM^`NAdN`oaHw>CM-sPn=v+ zUq5I1tbtis7p`5QRJl-IUiQOZ{URnJqW#{zhNi}ww~F!C*cA(wES)o7z;J7Uw`XGY z=%QKkuxRw+?B|#6P+$67$ zi$1h~e8Xd(!nBjnH#YTFNHqadLQ+CdU4(}2DM_$G`liRVl%82!Z0RwSswo?MP4#*& zVbwVD`r1j4<4G2^qaq0+XI>aq z@SB=QDKep=#)aokc$;Hm0r}#tDjcM*wow4#FUVfB$y=c zYGGRgNbx_s``CBC{=dON{+Kt%Hr!{9ox=v!k9=xnVp6>GHmj=(`)FS}e+}=Td|=Xu z&KeUF=^)5i8CO=z+(UaKmbWZK~%?l_|De% z-q@5iC?hs8#=x-ex1f;VwV!@4H7n)7j>E?eojJ1q6gIX_PEQ;>ENAW0D~FBFb3O)c z%r97f97oRi6Bn;sxE>N7GIq)+@jA!})5c&MHf*_b?s!p3YC?c(&d8dm+w$RiTm9&z znGW^ipIH;1gavWz&B8`qba}y#4_YpFNAgwDjR$K^y%!&=KvMozURHYR__BrWcty_Ec@5PiKT>kBQ-*Lq0gSQN0^GUq?dehGBPk-X^ zh!A;?vhdJQ?5~Y=!MhLc$24eMbj;GZ^Zot0pQ;u-^s}S0v%bC#spHj?@zJsQxkHBM z;fx zNU3pw$>)@$C=`whK9h42k}0C5MgUV#!V?1*$VzM^aG0q~Axy2lDMepsd?`JFq@u>f zG*44UD^4rDCa9IkDjA!4`#;`yNSz=SissvDroKv-GG-?7KqjP|rxIBbu`w*d-6-M2O zLakaI)t1?$*T(IG@^X{2 z%%3=eBSy7wr`RY`6`+)?FbR#MSCK?$WfD<{kmIdn%1uP)d8#f}{HA)fI;MmYOkY$P zeAe8!T70*-tWTfL_I6{p)4@Xre&X2&SA1e&cy#x=-K`t;)>hU;$4B3|b`uZmqw9dg ztRJS(w{6@ze)nHz|mX$2lNV+?*y1I^;O&{#Q(n!o0Ub}Ro zuC7637#bOdS5t0YzkB@D(J>yUBVANr|H9Xv#oj)=3lfv!uonXh>=NCp?9w`ai}ApD zB!Eqb+K96y=Xzpyhz0}=NzEtoy z%fP=kd0jXymw6p6zKETZIcom;df{-maqA9Z1-N#xr??(#A&??Bo*^-CfdF>;MSum8 z961|gj^QRNy}-cgF#%Hmlspp&L$s0}K~+t6wrX!VX0HIS+4zDAG}Oyxl-&Ux%m!BA zxgj(-c+S)rKYQUfXN%6qMaM3kyC5t$q`s+ zE}lPm{_NRH7u|Tw3Uy75*srFpu`wrWKz?36ha@E?9XNL6>W%BiPMpB{M{a_Qsl$xa z^x0Eq{O+}v@Iv{?<0ecPW9(^Z@^Ku5@v(8AdHkuNdHMLeEH>?l2n~ymiw_Sm{+gwJ zKpr0*6B7}MKj8lH{da5X>xxTDjQ#7i{l!~@bE;jytzOAA{K$)sRZ>cV7@0ugqmEK2 zfmihs2^A&0siVM!wO=%yFP8ZX~Wbop*G9EWjV`w4Rfqj6lMj+#9o!|4FgE?-^o)#- zO3g~cRyM*U+e;v^WODEJL$!4c?H!%J{qbvofzB`4;+v&d`MZ0|{=2sx!tU5^6G=$a_zvqD zIW~XplBsXLxbCeN*A32hPQ*(IgocM=%WwR{9_}S~E5=M7@x8DA@189OFQ2=blpG%r zXu0}^cxozZcWpk@*wEOwr0bX8f6+g{mEpBJogJM!Htv1q3u`0eBD=TKAx|!_;XY^4 z6l14vOpURF#N^hOxfo0g>?nD6Q*5(yF;qMv31DYj5ttf7??qOTi6}ZrjPEw16-?R+ z&BOFtcJ9OL$ym}t-l1_+W~gfRcne&jBM(`Fd(i4O{w8s!O!5W^AO}@2&$#kS3Qiqps0seZu9GARC`qJc@RP?Qa zL~i8{Ku-eo#Mde%t_Q!zDsVi3y-kc9Ih&R2Ex8qsdg8+wN>y69L*H9xG;a?>O_DyO zWaZ**R+@%_sj0qk>wB9{9y~F0)Ub?!S>9YVyNEsAF}iSnncsEQ^{p+fArWEyv0<_D z>BZbld%t%(t9RLy#`r5xWMro;STQRlEio-SrTF?CESB8A<8a1+R6MpXvOjX*^sQ^R zqhq29M&*U#p={SDEhFXl;WH)0Wk>g&%*h*owU91L_iwyb@8M6a*t_-Mu8n&aE}Plf z)Y5Ge5^fY-BPJv)1pDM9r^c^)ZR3GmM~ZLW#p8q?rXwA`hxeVhbK_22e00I6e7x(J zZc~xr(F3PT@0K0fcOq+0I{4-2)@7V43?BUO=2jNN2_J_aWnmqIS%eo;=tQgYn9j${ zrUnJGsx)(T>U_@RwYT2;VB;3N@iQzU91F<*<;$O4zGxmE|8$K7ga7rn*M)_m%bYEl zKR9R3%qf_Tgo&mkoO!Tc_|o(?D+6!VU8XZ8M)a<(@BZ*-NK6b3Oh-QR_}Y=f^7kG( zeCkY5M0jYRe?Wa>Q)y}0v`OQ?^qHsd_%=(>M^RMDtKRW3(*1Gswo^1&m^K2Gy%~ar zohA=&Opy_hpMLzQ%9?7t0yQ<+(-zpcOtd)dG`2Js-@cuZo)!}5e5HGPN5_@x*YH-# zpn!m^^h|wHt!-^*E)`+Fn#|PnRSOp5y^S1SpV7GJ*1Qk4Y{VML;*z_mNy*ZYhJihH z)ad{C>c9H?^-oGnHaloIpQVf=3z?Lamjwg_CB(*vKSH3&q6D^I#O=@7<7YO$^?s;hEe{++!$Jzjjp%4=+rMKU zUUU_k5a;LbS5Z>FZ}YC}m#&YVGA=4cUL3*lw~WkN+lS=2yJZZQfxslVEXJgT7&epx z2?9d^av?!w6I2B7R**tlxvO!F$uc{QZp@)L{;Ob@X5`p{ho4&hi|_vKrC+~2Xz0Mn zGq8(MAojP}{NDEJs@jDsXRmo;X>_dHoa3Kf`S6wBzJm>I=B}9OHXVsXjqfmHE)O%6 ziyoZ!!cSj)^>=?NE34{*$EW$n6`gdh9liTGrHc!P*qA73eX9%Z z^uhq=vx-Pbv_H>Hs+zti@eA-*KL706hn71XjgZFb$bun55akQcee&9kTTPA4O)af~ zf&Q4695Og}#Lzq^*!9bl)@3yRO;w;usW|BzA{A9UX^5OOT;Oe}B5g)d9eD^LbBs+| zPQu-}-c+V8Ac0`eIC|CiDXb*8FshT&f~z~pCa7Tc_w&P3)!FG8E)vV%_!|tiy?*LT zpA`r}7!VjZcKFCA9(v^B)hl2BuW#eI=2$%091)IJH&#_wpSgJ9@bTlRNr{VR&l6rs 
z&dzW(wRM|!ZpVvp$B%jdk3vH$G89ol1A~VT8X1ih1E_U z6B4v~PTm>zk=)Cp%qPVN%&!Eb`9W?X{9mJBvm&B>ysDs0W^i((kWIToA6@DoJ=kX} zh98~&OoyVWQWMTwyeXA*xKV}x>-klys-cKd6w-SuXhMMM+QqAHKL0B28T|b5vb%n{ zg+r#yn?7mIlrY!ZA#v4z?ZV|Ze^W)D*1oIi06JKJD6WJ7&jdu#i|S(9cinT;)# z5gk1Gi5*QBh#K)aO4ZSrBqSBRa?yi@Rg=FIR5Z1olUyL!M{+OEvgi-$V26WIA`Ji} z8&;ga`eNUhWe+brd-CFkZ*6}5`!A(sBn`|PP<;LFsUv6c7&V^Y9yPvzjb6iT@tS+< z-`#rZSP{07!FDnJ*!x)dU}+fM4!L{Nek`Kw?8H6_9Rw-8UD0u`BR(+(j^a)g)`Po` z-?)Ao_vFcGiP#T{zz{G!C6BbeeQsR6ed^Ge=9bpk%VvxiGgQi9`NMOuDV*(iuGJ2k>&jzKxKwBffJ|S=-wy#RR|wQTW}=l|oZ7`Bl&tcrS1p?IJdO zd&J?f(>%O8*tq>Q?TkcQ_)D#n_~#4ijQ{H^tFlvrFYk|d1r5d9i6GyTlVkC|UO0HT>JVLdJ$XyH z<2sPg0{_}5ILg|c7m3(cgVn&&^d^-QQbUR{Rb#lLJf8&YUR{g(Ll@wr?~ZVspL zD4l_8yS@L9s;oy5D(u2_nWlyWW&vIPz2;Q#_Y!s1s?0dP&`~TR9AT^ z{apF)mwdb>Dp4lqynF%kEW;(W$$6~NnI<+>S7HckchVe@Bt>JzH7>nL3k@%>b&*!o zEEMxCmbCU0rr@EEB*8PJ&$eJY@bf&3^+f18f>6^_14h2jTu%||P~gd^6@-FU=KTC< zOiZzMvmbV5uQi#mfMbin(xKEvyChLIX|B3@y{GS;c<{L7Qu;ZaH*k3(8c3J*6cegu*p`o#jXzgO`{M?SlZ+~xK z({B=lS=}esosq+sT)(U{XngXF(P*ao(MdPgTpDWCYqq4^C!BD7FQe&H&nYApug(&N zq}L+|IpUI{Wg*@O*nyHn@4bv(_p&NJjh<`n;T{~Ik^Eftq`(kY^7nLn``N~GalQ_; zeklH7c47bKVi8DUqbc0@@cjN3b(O&H2qkx1OvF6bG^09E#&Blcp6jXT`rm;I+IZv# zm5GF$oxf=AV)?&ULR+6NUe%8RHe^2@H-%%gBs7yWaAU_thex>!Sum64nVXEiE4}^n zTLqnNv{(IQT@)FB#lS&5*m0jg^Sk&X0t%0HUNHzZdcR=lryk_L*f5Yw&D(A2FKEeB zA&>bc!0edg_mvDo{Q3B-fs_Jrgucx}PYn}EAtLL<(KuH<5&lKUA~Dez+Tg*bF=j=H z2ncrh##zRzpe_Jgw|Sx3aV{+)m!T>tfCMFqg!m&f`iti!*p>H&cmpY9%qE8Q&TZA*O&ajrHOjz1^qWvmS5JWdLMk}; zoAlrwDlxBjx9He+WdC@g#wi_MJ0)E?TcBzrZ+HSWcGv5q5z6P!>4Khc7UFR2m!3mi z&wFZW1E4`Tg?cwz*AeutTcaf_8jkVe@FQ9-`u(=e2d|%>4)>5nfpWjMp-z&=6->{>RgUg-U7aY6k05Gjl65q4%!fZv#}cRD1hqvhqrr0&wbua$LdXH7=2Th&YDF z|5QH}#YZgjfjhBKZOQl?sG%Q3GWgSHIF+D>tqFtEJ^P|{s{j2SmX#!9+BT?M>k`v! z2+4|8B1AolKd4e9dEeNq*8g9!RrO8!mH%nDLW@?$z7u!;e~LEme~VU9lQP#(i2P+q zYZh+Ziv=z}mK<8+e?8G*!(jsfmWgtv=gB+S4#J}hdcTc=j6*Pi`3$UrR?H{s8jK;D zic^S|v*nUBg7u-19Ef=5)`Yrwrx8`5DSv?Mgo=JSbD#h~ruqCRWrQ?Ld>Pyip>OK2 z44^3(B(V``I%>SrSRwy`58ati7vg0V=nVtHL7O0jPJ=I{HA{ott*U8z0AuHhy`8|3 zkT{b_{)b{viFDD&3H5+N;C#LBiQRr;wkf-LGspb1^!WX6VffmFsqs~!ggME->VJ}U zia*QHFfaXD?s1;jNKzD-VS>@TzDk^#Qt<%^HysRazU{Xz^58@dy>8utpcG%yr@5$X|$OihUPGDXxLB>ukVds^$(iHv5gKJBY70= z15f`I-POrVrdHsirRdK9XBmti$3(v`Wo}RV)c}5;9=5vT_pU@tqMh#W@A{P!n|9fT zEs#uyTaR0}6yZoK|IB2|Puky|r*0k-m5KS9AnOhw!8(3H1pHBz(0&E{{=XYnJxcT# zkw!01ZY~Bj>xmf0&Q0qF3@sK)40BqF%HQ%zSdxVEW9~dqW6y16`W=JhuMnMK75ynC zxRmJ~R0OcCq^Ox)@_)<`#*?N~;lYVIr(H_GZ(p7}QTuh|%SBLB@usfBRR>xNQOX(( zah9HEM8#YUFZHt|f2f0GfX?lx8WHCq|-@I zCrj}xK#%vPB$D(YpNr-<}NK0#Qaj2B@{Bga_QaMV~WtW=Bp}O0&#!S)fn=c?7r9us-p!A@@X&Fi#vU!Q0@qALo?N?W6L4>`Z zVt$KE&L&OO*Q#^Phopl;YLMAD?eF@C$!t10>UM8U>G^aN8L6e})#j{1iY8tlr9aNr zq5fwsWdr*jDi!*93rTvubbq)BDUfyN8K@;>WGCw=P_;JjbfB(O>#S4(Q#_r|QVNJP zMYWl*#9Wyj=5hzn-VFBmCSt_$K|iAvHO;`>BG&=lfiUDAz0%(l5)KJ5QE^w$6?HB`70`4Ug$hJO?j|GCe=ia@A?ZSdsAk zWsov6fslif=cnmN>Z3h#NMGw1h^7Vv6!$O0S7H3B|D~K|;@9iw$Ehnc& z2h}YaL3jf@!E|k8O^wkGA!1=NNF_HpGNBHd=X1Cc02~&1YF^B_x9_*wJ{28+YC)64Qh_iCb#>wOWpSdZl5U_FNp-22$ zL2gM~Fr>wL1%7U{RO8rGTn?CpZqhu%NMLhwv53y$6`NRGP?i+*yZxBnwJu2A&%rjF zI56I1vonpOUqsCd?vHNA_tpSfTjGKmIT>MLU%L-?U!cmQA0jjvN{wYHaYKA16i1i+ zsEfQd;?Ik~Tv_=a;MeQ4VZCvF{ra}QT@4s1O#hyBH>_xo&mbcev6dv(ds8h#E>lf` zQ?da{xyRR)8%B02XgXr<5&V*;8Yn^eQg9?QI}3;Qxu;OKASgC2A&XW5zp|aQ78>Qt zE8u-EY-s>Xm zNkw$3jd#0pg6cZi1XSW>6d&1PR^%Dkc-h7-x8XG`X!1WKM(!bsCsce8z-Khls49W` zn82zzT($?`I#;J7ARU(HjL1#--l3MyVE6CCw{f*y)6HO8poY=7*mRqZoSA4v{2)S2 zBKn}Sc#iaUkQ^7=jI*oiU1VGup@>tzLJ)H1p+6!z^H7fnmz?{Lo3(wkRAIlXZ_O<8 zutCqx%EK~X5sGw9H>TnBK(l|*#3HEgs0C5vLX02ZFDp6Q?a^7NLa6CJI2rX5-RoDz z2FC|wgpS5;POKos6dN~n--8M;Ygj< 
z(i^v|`)vH(YAQO8tu$Q)z5-;=dS41g0M3z`tLZ)imbGb_(a~+lTmz{>LA+`Q*OB;e zeR|tsFhqtoi~ZL;8qz(rIO4{Y40FbaT;Yx6y3{lSDI-=^Q8xxCIZ9dxAgUM z0O{>Fvxdb$x`a4|We+(7Mi}+mDF{XmC30e@GdSYb&O7if3LKbNKS}L7+crm04_W+` z6NCCCre?1?NxEm($5-UqM@~Q3DmcR0<0zCB-r=Mv#1~>p&dPdkWEC3bClQALxUa1q%ebU+w6zUxlv zp-KC#q-bcC8y=p!(X@m#ej;U$(Y%qA5fq}}lUuwX-s}4@A-NgQc4uY}l{%9n0JO znNQ6WCy@PDL=uh^ZaD%~jC0W7#C;T`p*yGydbXMt|H~;8H~qzetN5?R?UxTUwjjP> z$q#$cGO@BQ!sc6lt8MK!Xf*LLoYQ2Gj4S6JeyjKB_VXbXBw#CW#p8YGzJWobS&q!O z5e-~GYuOVgm)U|*MfNwuVHNP0wvyVWVqg^7|67F13ZeoRB6~ixbuh>iGRU`WIx)x6 z@Ed0?F9~l@OK&sBKAp!Yu=gi+(1vl-PD_*L(lPv7W#VMwq}6#UO^6giT4}p&@G3a-=)8|exx(0$cx ze*HE6540u{C~;vmS>%Ro8N8!DMbJc|+ZXCAt5alXox$tI=M*AyOFY9|-~Cks!Ld1F z>T=m!sqA8b22LjJa?xs*vxDz#!9HAuk~D^pcY+m#5ov}ys$UxwRpb6Q+}qER@-{y= z%f!TEg|KeI>i#DY!RUi+z?o_a+75J;p?fpfuUp)o+`II^2Or}&CnQPb?0kW%Zg&U&6>x&^*=l&ziV3xO0EryjEs$CoG;U46SV|ipfVeT5FYLYQjs6# zg9kA@mHKv$_SQBwzmVMA`o6{JK2A}#JW7B_gm0eY1FTEEHtcT(cYpkXg+xT;(JJ8c z*vN?oLW;@Q%jsBr-z#_~vjl6b`$t-$QpJ4EX8OI}9-fj>q$s`^HV2g`D7YWs;VAqO z;&JQl=Ac#Xd>vs--Y`b?phqgdQRPWMoh*@vp41WC*lD zRe%mBvU+|{&z!GTkuu<86A4)EB&DMq(6%({e0aOu5W0AQG#m^Jq@^uU+buvB0X1!G z`#rzi3@$>62YPQ)eZwU-JIoL9Aw}9C#tEq*;4;vI3z;Q{%;bY&BB?Oac%Q@i?@C%#yiRxx)#Wr( z$zp+v;b&PA&dVP6sGd%==M7Exi*m6bn*us~q;`IiuMG{0plT(}-BLnRiXJosIb-Un?olA;$3a({TrHP?KvF2Q+-|KheP#vW}e0)Ds(DSBMEz+3Ai+sLQ%3 z-DIn(qHvPBHG#0zwXPk5oV;l}#~;td*^z)oRtH};8=+=F{}aBa>tBV_AjNnJ*cTjn z$PWbww1ppn$Oq^+YA&K8@nAtRuiYevi^xU;HWAn)r#x`|m*3&1Q?NoSxs70*{>UeLBQxD71z;T3W^Algk!od)l1Sgs&thqM%;qX2+Rab-v z%`Motn>DFqt`~*}vzR%!#zM+|9ilAHm6g)#@*7q0LMb+O&gprH?_;ta^k!F=8C78Pvjg@zE(|LpFq)H01k;-zm4mKrBU#p6lHdEuA zUytVpvwSpv$J^W43AsJ(wR&y9xyfj2tGkaa4XeSy2eHD+xk3FJsh_>hV93bG5JzE9 zkfklmt<=s&jMqELDk~jLg4k@9*5;bsMhlHz|9wi{-`81&pJC_d2`UKpa>O$d^ewEZ zYy{cyLmjP6cUwK2x?XmYM^gjg;6Lr{fvIV8n4R0*PQEEtxWxG1Ph`Fdc-^W3@^YTk^DjDkKnObT^%J~w9fU!(7A2v}5F>u=!XQVB|&&m$YCv*g7L z3T7ke3yjgXN9J}Dg;V`rVhsfzm1OjmRw3%lb8qkIGnbX{d!8k9J-2G1u_kv^7+Z9# zB9kPm$)&+_(scjU{XMzX`em=PFsH@JDn(jFm5dG4zirKKD9roD zK{XdGLpe}KVuWPIK1OWh1GZ=QAX(@!u;}68QL4pK_+Y4Yhl2VUG3vIjym;4CranuZ z>jUmMg%tWQU~~Gri8gg#cG*8X982z)6DjY8!o6}f%tG40$rtCn^=%Gi<%Iix@)F|u zW+xDsKfUEQKyH(dIU5EIk;H!orAG$U>EUiKi++ce9+@DLu2oE|{bd)LUPcm&{c+vb=Ag`#R8%L>oyE=ADq#K#(fT=jBIVV7tm9Z*ZNJn zC44=Vv)JtBSCqq_`8@wS;Rud*^Jl2epj!d~c?DVlKMwr=RZC*xB=s zhGyERtw1Dw*Zn?y$YqQpESVx+uOk*r@<~Mltxw>AeyBPisZe#mU?KXV@N4;vOl>h2 z7ITO1{nUogJ~ee6-dDqpJ2t_!m>$q_ck9Xuw6OU86-K1)klrtabXB?c5N! 
zePJHJq5XnS=$0DuS$i%$T}YG++KUVrqRf6#w*?|(X$+e0$AW~{+97=oL_9dCqjzy$VCqgAC<8_9093iVIw<_vLN+dfLBV zt~YJ^K9A3pm(ShQ``h0!u+byEc6uG&^|G01mivVN*>ekWfjxf7+J*1t>)*BX%M}}}x#i>oh^s>=W=;TJe+(V8J!OnLW|LE(l=?zvKy&oG3+nkwo>fbg{z z)u9ARDICpt73&wXndh+W;(KS`LL0%2?sl6N0H1T>?;l;=&Z=^_zYI%JV8z*`y&4Gu z>~M}`$?2?MsIdF5#Qp-2j07FVWQ#MUCXyIT)IX=6qTonfm?j4$r}_!nnvmM#PB01~ z6zfT>dfxpjhJE-FuX+s?WS0F|b^NhxK8zU3YxsSv!J?7p^#MQ{&ylLwAyWdP2>tTQ zSN^Yrau}4r>AZ%CL~qOp#vI~5tp~0?n(*V zXnK@iZ3E+p7MBt{!lHvj;ipKW0e!yiMN?Ei)ougl&B{Tr8;w6#-Sg+bbDzYOrNCrV zT0o*UJP0oa8?`jHDmdEWrwV0s^3MljKhgt=*O%8>c6KP>l}ef9?g;zK7jEmzt?AtX zB6R};K8et|`T5SLrP`=3l%&*R75 z;};u!fqMd0I$F2y6DhYB7t0N9YkMU@VXbDw--FA_=>|^$hHv2!y%6l8Zk0JHxvRS~ zy|yOD7FBou-lUSrA9vS}A$^boC^3x>4DQE6C1o3ef(OnwYLS$atkCdTJ;m=A z$_gJdgn=LAnbbF<`2N`XuLc$jJN1Go+&;GU5eE9_uqBxMmwnV0*f;}mfpAic$G>0k!+eE7>7Uqz zl=wswe6BJCheTNEDF*xlWa~Rr)t@ozO@F#z50}bXZ%L&Y%8prKSrHBli~5RB8R;2| zOTqm3qZ8jAeetnOt;$!5zOrpaBH}nVB6qTbc6pfDVzanXL{bvQ@x=X_vR`bmBonmEp>c*aWqG(OB);2eh#pD#+5r7|$U}1qPgANm+tH#Kn z!NT>im@4L&f)i&E*|@Z*)jAiS51;hfjSJOXAw9IZYELlMn_mlUVY|IX_LI7@xcAs* za_Ss>t~F;1%Fj(Ic_6~kmyc=IZJ!sg^m(&#-5VO3xOlj>P6+=He|YffIyaKHaNuN8 z$HOgTD=`)zjHqr4&vI~ZaCCASH1m9$ILs*i*lOZwy43yMccmNn36rl3$rrPxqNAar zVOaz!wb>2Clf0-^y?zYX_V3wU?xyT+(|o}{vxRLsl1Wv4tH^S$=uB6qa>+sA(hPJF zyJ_`fF$0dC*^dsJ6UYdYkd5)S2V;rtVkffK)v=K-uB^?i1@3P)N#l8gYB~B7a}dKoWV^Jg{@2~(CBb2y&Qz_@x$*76lYbo zl_Db}GuXeo_*!lfy=k>f{B8)oT?$!TzW**U-`?IB%GWLQc$fd{3`gDNg3Lc?vk|+hH#$6twcYwR`rTNI>m$1*N6|!jl$ObiJ{9rw zgpCaWdUSTP)1arq^F}i;U0M3uV83D-cSmmS&eWS9rPrCl%65kD?5wit zQ1^*@{*@M#&wtJhcP&?IIWwrZaU9S9tFY@oZNJjx5>04eXwu83PMG%+35)+ zn|fbKR~C&Fv@a2gb0~%hhWikWygL}ewv(}CY*NRk*=9DiLZvJN1A|V7=diXrdm=&J zi+8gNp}!~Xl^iP7KmU=CA^!XzfP_(Qe36Mmt&+Rdyx|Lv6bG{;fP_UQn#;4}_rBgtL{zl?ZU|G03t>KT z&F<)WJ-f(we|hh?-(q12I4Vom=_C6jZ}W_P&hPQ`&+Wv>M_sSZWme!CkIAUvg&@V{ z`R-Pky9+dV%vH8P%3J^3{VS8W_oAFSN~^-)_GOh(aw&&X)ClC;BD z_ocAFLR>;)zM^84heXK#?{FKIwWzD=NNs^#mfzF<^Sp+hmp#9?H4pXAwU3e^{!L|i?l8;GLfdnZD-b`?z?a+UY43G9Z=h1|z93>Z z3bFBhYXUyZvqHbMv3W^=udt-#?IEI1bfeWTFE3Aso5pc%DG+yb5mzB2`cDz8^S{e; zbm1o{v%LGQh31E)7tFlF30*-SIX2Cc(o#pqjYh}=B&g1EBEbJuet6iRh%=K}C{pqJ z=G0F`7>3fQd$l6aj~j@l_L;tpd_%ZRS#7wxiGc)jkx&Q!)79CFrr;lok)1;fP3N7 zOZ;j@nQbEL?fb9b;Ju>k6;X+>a9 zo)Z}aIsx`)5`+?0XH{l+`_gD?iYc=_pB*sbgKq4c^vEH?fn@KA-UK=*tH1tRUo(ZJ zebj>$_nvSI9S@VnvXj1a@a0(VQ+|q47I^v>~1|hX!(Ejs=PNRYKdW$Y#y}OxWsI7+HCc`$QlDb zo$k@Am0f5W6qk2XulCm+ZUs*<|Ga|ND*R}#6BxLyL=~gzaS4$}h}N*qnr&l~XNNJC zSa|^tYuHo6REOmMd~5XhoKNg=8^`>u$GI{V^(aR_B(8Re{?qs`bRBtE8C`;%R`c&t zNY%^y2_?KCY`}1~4z}=35x#9mW&XqVbe;VOwwaph&a55LInSJrS2}CQoLN>VIgH=g zY@i8m@4ksi;|wYpCbzasL?mlra4!U!KR5}9t-YnYiU-%d0 z{0(bX%^V-4i&oYGVi@lHaVT-QUlR^SPiA_mO3?tGURp&cpHYQ1ql{7lNT)@f&*V8G z=C&oVpbC?q8QRQDpsoUE-xzDopoH|pNGj?c7<&^V*8T<}30?#d`g`9nd6B-9&G@|h zo33+LhLZXgeU|!Q-p2-9zSD4XL70o;X3RgMQLgbZjq+KOhUL=s8bgiq+cQU2z zzVyqO-ptwzaAz;4!?#K3er?p~`Anqy(N)0H@9IELL_v9+SO~bbB^Kn|sHnK5gW~b| zc2EflBQl>X-)UJ777FnEOw4O7bF=lYC+K{wLk)Of{tfsI!g+Cc23_lz@FbR0#QEP) zR{Hu;5bIF0U`<}0*?Z2q;@7XcgJ1HkQ2ib*-(oXKFP1wjrhE4#LboE|g8&QjS$w|8 zf*h7(QAAil7Q%eaylo#J8d>0^W_@ zzECC7C_=euU2aXbt9LPYEZOfXcAHsj>)1Y!na8iI!T>(o)1`PbQsMGt;Q83}Zdm8t z3MrFb?VZYNUZ3=R=l$ECB~k1FC{LOG+r!p&hsa~o1v)^GDgZ2F7#mEw{<1G10^l`v z*4KYY8t$L|H7mcmqN|jeC8*~Q+*qSF(^Zvo;xE0DQf!=YhxydNt zdo@sDcYVP0`Kg-M`?SJqquGC#AxlnA(81GI7on3^MN{Xo#BPyAIi1rFe5n%j>GQzP zwN^jB=LzLc=m3S$cxe#eVH(usd$@5{V}|S+dSGmiPQ2#RdirqTccz@e;)|@8pIdvX zb*L41+Sr6A;KeKUw-j}`Ow$yR-oM!L~`@izCC!+OF={sm{NiK_*&+A zEg11c3hTj_&{+2VvO(68;*e`Bjxm5BddnWe9fb3 z>ze_Yr%GEtLk0fZJUI9QEL60o;Kg#+{vj^@ zLsfP5@q<%bk?x!s*e(|Gu6}4*0@^-uZ7Lp@rf;Se+0(nhvgvNpt}a= z)MTtpqzRR`aT^=zof$?BX+CCO2luz(Vd{K) 
z&)zuykU>n*dAbA0*M9t$9+S+U%{uQO^vHvMv{2K z1xU+?9V;^b!hpztVi$#Cj}@fw+v}tcRRqaBbw85uinn;We%>VdL~3eJkrT|y4O7a2 zMDG%=6^FJD=&7%~EH2g|frln^tp!mm%64fz*m7E}ez04aOb)8ArRU4 z&@ElsSbvq>f#~;+s#g~v0%C{MKoXA6xY}Ww3U=q){6&3}S5gMQy4o`gqu1tR-aDMR ztLaW(H#W^6i1B5;>J@3>vye~_LP24^=;3KPit;#Lquu76Q2*+J$n#Z1qO{lPi~I4M zPF3fXB9k;X7gt`;fBVyR(;ZJ2WUz!j-J4&$b#Uqyd2EsnRms$J=2uW#3X?t`rMI{# z2A0R`3U0>6(%S>%;(a4~ERCDA9{mL1xLE_~i*a&VZ7DG}#Z(&Y`B_w0Sf%qgre?>z zm3nyIp||OCKTauq|JAS_3zXUykWQe3ZhJw>SL?M<{_!C zGrr0yPmWCxdLFr2^taTZ9I|BdSPg?k7bZ7e-eMP0Jd``av9;KGk8EtTkM|?BBQ{zN z6ZJHj4A(AKM^h@5!j-`xS@`wYUAI-MUIF|i&d1dzK1=kjVV?(eq&E6yUCTbA-AvkQc zMiy1!^;FdoDvL3EvQLkwf9wL)~}Tv z&}3i;Z0$VXUUJzTjLAlf&DidP7^f%;dv{$aELZ69I3LxBqsI41$Lh8o(3f)qHoBZP za{HsS%2io;tNi=bPE2=A(N8!0jvuD$NS(mBs}u2mXmON!Z?{lVtGX45Rs~id#Lj}P zKpA`D6kc3}4jmQEr>^&d0BR~4ZlAIr9*3i&F4K=tD<$XTCbX2Sg@wlF`-{zXpS+3| z0=xLIyVH%P)wVOrteu!%tfSe1Ohy+6K=;GR=htl+S#UPEdMgU_4Theo=`;-qJGER} zM+IoRIg#1#b~u@}>cQSFe)VTT-O_$@r0y*L$DYPW2u@$HT<)Z$jDNOZe2;xcY;=W4 ztu{GPyvFwU2LNSIz+0R;azL{;bw~E^5wKajTuZt6MZW!gl-q4hGPzFt3;%|8xGcv0 zxcOKC);Ga@IsZ_C5=2@2kKGYt12AYcVQk+d2K#n77GxvjILjofYsRB@$KQhQB9lT_jZ{htdY$IugW2tEJ||&a z_c0hH%Fc!VM z=#Sn8DF-j6x|OG4z0TLbxh8V13=Lxvw!FKU5ETPvJAn3BQ z_uzejLwbJMK$Kc#(8XDXA7yvq0?RADyu8t>XR{m;g*OT*;|>VEN8z`WxtSWXFW>Hk zj>o$r%@q|r?6zbpE~80B==tbx75Ym~7OP^%NP`+=^p}J^M#jQGK>Z9u6_?|M@iN9# z5g{s?gT(QROkuaNcULO3QEjLG0J{6;I9xnPvZj)d(BdSAWm-hR1^Zf$km z3gexmrhwbs*`nBq%sO7Kuh{S8DDoXOE=YZ$pLgTK5{-DAHhBX&O?#Vr8_E_B_dOgo z_rQC#9JDM3rOp1=&<&suR!M%en;`FyMIBo194Gj4CAi~PARu>s4g2f zG&F?Wo$~>WMU-d;QPBGhJ1x&r{ZAOB5{p*~n5kMbFb{zDy z`dGU==rYX~bcBL=w+@}u@3gq#F&n-0=yk0%SjbPiJBdvR%zXGe74#}BKx?(`MpQVN z;jw!>UzWUbhG1o6DUOgf`Bt4X^d=beN$qPiIc_T}d;oA+ zp4x4`EVLdPiIOKJ%IyDhy?$!Qe!mR+Bl1+8SMhmK2kH})*HYth<~I3Nw-GLY6nkDq z2ps0^gXjB+`F^12cz2NGX}5p7A0oVru3Bou5-@dzokxw!5hR*<2U0!QolJ^9KIDLXdJYO=?sw1^Y^(&qRmV+W?&S(4U z6RM{;2B_Dgr%qAr@K2*pY6g|(U6mD);xeq2qaio$Xpg7O6I*wzUGM|^zu0yP)B+tc zqH8?lu%H_6(J{Yj!#`cPO7&m}!zWSjXlZ?KZDc?u?`dYno0;ukQQ2q`-T(6!EqZD< zEVT=L)m;V)snNzX`V6I%sy;prp9tOp6Mds4*E~;O@9XfGAbX|*O_BH|hhYM1mPy7C zId}N4Kiok@7_h|H5irq<@t4kXn#`tMhDtfeNXCMay?3BPaCmg(*N9wgWv~H>OO>d; zVTU>7xSiOkA>8nRgwydC;`8u$=#NrFu+i>w_R{ahQ8YeGkP@b5a%op`XW}tj zLP_u|gdt-(A@xso0tHL{H8L0HoYrR5fR4X*ue^c`F`MA~nODFaX!H4Ge>96Zjg^nb z>tRc5x@Q>IqwAmt2?GHT`BP0#6u&*5VzTP5@NCySdZ}`ja6+sA$i#w!4T;foSk?yq z6Y6h%i-A(y$$o`E)#(^7c?7bxhHlH-$869Wo&+#MXalq!?FqkoL|t%QL9{b2BCG+}BA z9xHIL%SYckhrSB(n#xE4)}0=tMC2lq`2;r=#8k6g!ZRHXyse}W1K*~kEa4j0<=Kdc zN@*r%VE4ZjvC&cGj@vRG=lEMj7C*n22Z~lULQPATex3#WBO-a{Mnx!k63LyCG`g z56i~;FEhQZO<}+5x`5%n5KNX#EPJ&PeEe_GRGe@_NNF)W@wZfa_`d+7@%TnS@*du) zc5;$v%&?cZM27Huh6m{-10@l-%YDY`t4%Wvv_rm1!~2`Po%Ck4?ZW#_I~zv;=##>v z5i3>W$u`H=amB+C#*xvH6{@Aq4l)>+@eoe&A9tPNGyPu6A&Cf41t#I4P&)pg^UhJf^z)35LPV&)|31xQ$g$_E>hIsCrk9um z(&mv2`W>C#x4+pl$RY$}CNk}q!atAr6Dn6l7@E#}c&8G-4l@>X;~lkh=7*?$S6 z@ZerMcirfF;}57%Zjh&8Q*!^8_dMP8qwjF6#h0qJH#Dqh0P1{f1{h8PN;_F1e9peo zGOc((NQU^BQ}l@YQvM*is%&(#EN-9NAt2tk{TcE=4+|)`n5%!JrnUJ#L(nZ)o+9Xf z$&H%0iu#8N<5;LR{$}I-e}lS@{r?6PaN$b2jz2K=O)!H$K@Dj0Ku)XCiXhUW3c!!c zwW1&UkQm{pM>_|Jj&I11Npe0p*2o`x-o|=>P|zk%`X~v?xQ}V*lvUZ6;(sB!N7R{9 z69I@&Qz>I8>qYYvI^Vx!hU=v85Z_b(IDQs+SdAflKlbW+3KOGy>o*gT#=}+WaNUPv zCI)2v34pj-*tw04rC&Sq^mQ|Z^Luq(UA*5L4N95`Vbb2?cHSK~6)L5FCgM3QEoXXO zBz-rN=*x4%W0=Sk2+|`(8=dG^<`85r=&Y$Zy>VgkEhw^5Oc&g3u`=w9@K(nINj^MZ zymfhhhc#Eq(7su%^0$-2G%8iNj&1$seA|)Resj_Jtd(E%*ZA13>dkPP%E^b5oZP|^ zV!yWGa#*SC{HYImHZ~>suBf?s{O|MlkYTO&`0~h%g8=vTK?IiJ=G4;4N)`qS%U4h@ zkXeZe+#l%SVIY!iXA3x8>iDQ-zU!Gzt3(VY5b5VQkTm1^(m}#XW;qm>n@i9aU&H8hra(3lA_`uMZQQPP$T#RhC 
[GIT binary patch: base85-encoded delta data omitted — binary payload, not human-readable]
zYmDjfUMr3z&g64`>0)O&b7Qr1dT8sHmI7jvZZx_gtnr_CQMTCMTV5Y0{`&hjUo1L= z=Jf%!XDT;-nz_cy@pd3O(rog(=3hRR9*jJj8Fj>$pAc{E7WjPMPz64&)Tqzl;kfzu zSZEkA(fFuR!Pps}wzmy_hCfikp)kQg#!({y(%(kG<{Qs3`ccxZQ55 zJg)bGPa^2t$gUK!bN5{@ZQULd6JKMy!L0O|1gu%i?Nm(-NeweIwUIF8CIz(eAIIPy zDAA@TB^5sd3pGee>&s1=I`xUZm*+848*QVrTOBF>UO1VjpGf%L9*5y?CN#Mv&2~pL znWjzB+AU|cUawc1UU-_fdTMg=e;pl&93{}zPho2ONjjakq@Pc?8)*Q?d#%k^QiA+| zLtxhRRd??${N&q_-bRdMn7&F?OK5L&pK4t*#TakQH1P|j|8n|9NH6qG8H>JO3@gw zvKARJs-mP-Y4h!BI!+wqKE-b5U?X@T)D^q`+w`kJ7kbbz$p&%_Y#TbxpMnt+TuDs) z5=oRHivdRzD$L~eZ5v>R5QA)EqT9;{kfz%))>rtG93D9f;7LGepQKBprPKLqN ztOEX9R=C$%@X{i4hPw#9JMIuk?%+)dHbk2w;WD_Co#QoXxIMOA?hs7wjKcH(O3@Gs zLHw_j9}>Tf(YuUH(<8X}S-%gUT0W#YyK<(oLk6}PR>leHDT9+W`sm9bjGh!zzn_n<=l^tHm!YKxE0vvznloke>F8BB<2-mVQo z1eX9EOx{v4XN{Ek@2O2R_BV%AlkR%DAiU3SjN??k5m~E;g$i-UDNyomYjmKIf0r(rwV%9d9S2E8R*z~z<=iDSBBy?%{o+BUIDlwC z=X;~=Qq(*3Z}FYdLGedY;HjP}M-GuIi1MJY5!@{$Bvci6cuV4*EvrhOCiFE&XC#4M z#=}^o252VqUg)y6C`zsVoy&{G1Y~$xN9>fyFdMkY59aK$xg??M5w2`AG+h^uZw{yE zA88M~!=*6jFn1QT;X}Mt!n@M<(~)W5f)dEFEGa(S#F ztT4+|m>(R)^-*2HQ^1f>TsP3?#f6MG8O?Io1>_a{J2J{{xs~@htk!eCxaVG!HN#B( znNJgnww^4jmNyvNfso!lfCA!a_$-i)MkOH75RMQ(=WRCmEY)F1Z9Thk=`0ZkgbKYR zr=1tX#CHMJpDTObf>cQyJnqDy+$vgZ)p>Cwtsdy(ig|Tzlk`_4_Fa)>{z^B@Ig6oG#Kus0yrSGN;b{?WF#Mrp z7$!y5OrSAD`RIxH!Yi1FseVd#pIBm!i6U80l1DZkXCj}E5x_>j!}3SCRnFGUzM8;L z{%a?x*e^S}OjSW%jlgtTVG=@R<#;ZY`#2Fmh>(4HdR!|(RfHqInD$E3i_X*DNghS% zdB>OZ#85C!RaekX`Sp<+2h1ChW*89JF}S5oFXX0El1ZCS6p2VaEMs+m;dz`O&;1MW z@hwVQ@?T9@pC~MfFH`T;nfr7V7~HA%Q<$U1xT*(h_@mv^zVyDUTlfI{o2a)8=@Z@p z_r&yku7kZ>AAc{`ko_fIevtAyaV>#F1O^NkRdZJ9HM;M%a&eknO&MVG`RrY#$Yxxv z;{X!VIej6k2+t+t-$-WGR={7~3WJ-+*sAR~h#!xHNVfPs zH+QgGm&RtF^|X-g_mHkRVy?<}hv#o-DXsW3NqS!b`zwnhfK2Us91u_ z=qK#xzV4m?$g)`+>Ew3$T{d18H%XYh>~7Z+(K$xdGJILb77v7ybqrf7ziXrC!S4ZgkWbH>bZuArE2{ z7%VZvBO{pU7L_9ze#gx%>Y7{F3=rCaB||J(kvfi5aIwQSkKmXKw`0r(Snl@b!e+Hm zla@ptZ%;xtr>;!E+M4_25Sm@5f8mqug7Yl+jB}Nf;=Oota9kdMM#xObTKA5sda+r5 zy_>FcUfr|fb@n@}S0heGex|QC=W9+$g?iR>I0$UO9n)evRTK0v-OBM{7?%;QNV%-E zyi-;}BMvR76E;otbvHOP#g)a^b}l-2nm&0SuN+qj)nmp8H>972_i~x!xAHbYlXOi#uB0;Fh)uZlTCQKXSwd-joe=sS&$xia4gG% zaPNHe`B5_oQOUHjj6mzv64|`4Xcn-Kq^pz5`uEe&;vE`%Z5+k_0j$& z@=HZ&`_lzc7iW-lic={%t1+ZzOQ9!Ack02hmfRf-r0l*yCC=8?THU~x3Dt%!I?z$2&l?>VYT}G%VaePffXV=7$2dgVt`%)=awvQCQDQdG| zvM-UO0rVN@jr&ICmM-lwTR0$%I^Gei{?#3s8baW&1w0(%;OMkj`_Zp%vWz+afR>m|^BH^iig;sjNbGpyP{i#6< zn0qTvF|m|D#nn9DmR|g7$CJPnBO!(vE{9Qyu!p+`Y(Eci)zxZz3CE*RmF*6^AFKY! z%VEP?`-zY>;{Ayy=yQILe6x}zmKUKWm%;Y=@;oVU!v=gO;c6T06}x&h)RASdxlK*X z$e+;dlclVyt+SlS925rCq^m0{dg(8IoUk6+y#{;pKA;hBH#w~A(fGcVLwi0fW>pI| zxqd#^&EKZzMy^&lIG+^)Til0R)Hb}Hk5lbjjP2cDD~$-YIGKq5rWk&&>AZu-*$)%~ zidqa};yTzFKr{2wSgm^U7EgfuGOczt>%2)D-sp;>CnqKXiVF(1hT;>SdQmRxGr1gY z6qSq&Bz-)@riy7M9M1%syhQr#@KNtkn*z&XxfZY9dCJg%O>ZOP`{tgx_tin z-D<1dND#XXq-W@%TD9goCvY>I3CwOkeH3^umjARgv6hW47O7BJyDrXLI_uyLmAhPb zHPn*rC1^8>1!kuur>_zYrge<*YesU%?)~bKQD<`|f_YNu=UUEHr<=y|mK17p1_^u4 z7;)866op253PnBNvrOfyOEcp>^>Sn1!G&&eJhx-;6cS6jJh|x3Q1#pgr6PzcI&bxQ zegUS%+sA(Qk!Qsk?u7tubIS0!V|OCn&JIKe=&4YKbQ>Ytezsdo|Gsm#c+;54lF;wN z7MQ74#IJsPD6`g07h)>xPF&4IpO?+U#Y`NENE+GA8LIRtPn2pB+H;MkPzx^5-*cGj zpL9bEVa1#bA8GP(UwX>xSU%?W&MB=mzj+PegDA8va^u^6Nd#_}`ml>opx*`41qb#! 
z4u>SKkrpPQqa>E=Fg%3f57l1;bjOO1=%1oe8d1v-aTrqmmE@cF17b26ag-P6a2~Uh8@NhyejBph)WVG>`G=6Gwc(y3Eyrj>iU>Ug#4PWk)+xK8B zqq;-YvEz#wyXpNn41VlA?RPrOj-3Tw0T0G|U~cB-%?hPXlS1cl`y<+=FDVU{$-!R1 z{e^dKhd{gSriBXY#naNeQdwCEjD;B4Oqa6;4Iha+3@aq43MDZcLa|Nir_zVbwzwj& zJy^2fpRCV5c#!}EkUC8bW-GA}3XivbL=~kc zBfPD?fvJ7nV~b-iCKloBe{4kgO`czwn(B<7b9~$!Umj>J2J*r;P>v~`^7QGK~q^D#E0PM zJ3FLs-#a+GH_If@1&;nTFnJLF28*BNO^?V+BS*kle~f(FRUf#oKIixh4XIN)sj_hQ#!Gqr z9gir46FEO4c9BotQ4mzek!d7JhcA#cOKh3(Nt?tf_NY|ArRF%@A5vnj-$zx*6;kUZgbAA$( z%oCapLaquTs7eSZhqp0uNzU9YLti`@jSf3si?9)W-k9VG(?qxy7ciaj*n*#C*Jd`Kc2yKT4^b1YMJy1+sa$q@85wy0#+Zd z&WTpBhDR>44*AQn=Eb*a!B1aC-4-@4Y#ZCuX1nM6eyYsHAQoQ>jpe5BbI*9ks`d-{ z&lB5F6jI$Dz<@=siK!`2y`iF{t>Wb_JtPuU7LRe8xvIIj*-c=wN$TuhOcxqO1XROW z5e_sQYE>FrIM5WtVfFC&lqksWzT@WY?QPU6`gRhyK9+EEGi@{&O~T93nez^!AVKF) zzzwQJ%%B$KX~x4An*Cy;J`>_)8%-I1skg52V86LeDNs;#l9Y~l+MX67YTxrsMHEd3 z3IgdmwXKh*>oW>33#s-lW}k4Oqy1zUk7m42E^lC~Dkjoaqv54AnV!_!XX$8r^s*&f zF7Eg0rSuzz9Af@JtVWIDG{xrgH9_i zJ`e5U*{}1`$LZcylk;0>xvqjvqfzd&n3JI6d8_A4X|y)qy~1KWtxgvu7aLV?Ykrbu zpG9a+6&a$$rc%+>jA?&<_>T8`Qd5&gy#RjagT2WQK~lZ);u>W?48nGBUF_mkxTsFXqpzhBG}??;nX zm=t~0dFVH0;*Vf$?#Q4t%`M<(qc%0`MSRakB3YOWN!Gr3ieI=Z$gHi34BANDYYlyt z+aHPEFUI5-u~7-o9LCufp|3oSS}kpIt0Q-x;+5=gIIHhzDrq?`&X1DVv&WMdMNUSg zETe!WD5&6Q zEuhMH8aw=_zc$aeU_P#5hpo!SUg5N1rE+R#kod41@F)48s$7@vnK1`_gbE1JJjDvx3dd##<=H>vnJ3l~ijwNun(2fVtq~(4^kSv8hLqtJ^l(`E z_gg|f+EwA}M4;gz!f!uz0?faY{zls=#0-u3(qY~|L;4AhBh`F=X;v~*#`b(yI#5o& zzkHO}dmWJfyzd-=OB&bl%5+PT^^CvkumFFbEjHsV^YWc+)^a!>_z>?oJ-0t1A|g+% zfLlq83kM4VZb9~gr+d?h_3ZV*{i1gw)MuUtID4dPju5H84{>m}uGzT~s|=-Jv#Q(g zGb=+MNr^Zt-7K2-86ibYf;oNm4j4{3Dt%7+8*Gv8+H>+{Ry!Spi=}fsqr13!&SuR^ zOulB$Tx6Bdr&hM?(s`Zlq!$*7t~Q=k3Fx%nq%XSKtUewyZn|F5Tzn^0mnt{gTguIT zxp{ZoT(`e@Iv$W7Uy3K3{u(c(Wmye%!w-f^?6$3QLM;;)Q5wLKju|f!4-4h;}&9eY7fMdYhrkJ9_ z=2QbJ0n`&X$tW`^8nv%m@v`Z8zws1V%9AeGgjJl|^h!2&k7R>DGZ!%h^y%d-Sh_sN zB6MVfQY+g~GL&vfbZoLTm{c-`1(;TO8f;U-7grKtim!6;&%>{=F?%s zw#c?5!$*INq)$=t0vWJE@T>Az79)O@Rp3h+u0(Xhhei<%>sKP`C6wsGRG<~Zx ze(a{mh~h*ASCgS3>1kLhgv0(-G5KHOf@3h2ifd0fO(wqX4kx??H!q8w&L1O@!Pw*2 zyOM~1f8y_RGc8_BVS|6iS+9SOJQIMVRMAK3t#FntJ~n{kn6VEw@YrD^Jz~Bt3dQ`J zk)D#7^)0JN)`uQ(F#xhh0%aTVn`5idt>*bOyO@f{Kaq-``cx-j$TS8aSJ+}ZJtqNq z0gh&qB)m@8vBU-5!l0lg-%B4pHtNojuUdT7Aq@<1wOa%IO@Vt>lhxcld-ygtL%)r; zD0iAitN3d_r(!>!1$`gRNq7CFvC64S2!BFxyFJBGH3Hv1I6LYNmQyDptjB+04enqM zF`RNvfCP~iuYq-K@d6rINt$8w;KWuG_7N(ZBpztdYwmfeh2_7~$* zV~tE5w?M)&qEQ7HNi`V3dua?Y4i!;qW>|5>W(c>bg}rnv>T!^tBVbUz{Has#`M|$j z{0r9V{`uyXo$TGx$gd{=;Gczs&G67!|UKHBDXv=K+vXSR93jXGb+${xq@S zECqJPXl0EZEu91{9d0rJrL@Yvgsphu($pmEw?ssrM8R+zDD;G! 
zb0Mv-5@(~eX>54Axz69nM%4h0HAx6i*`8G*Iu`fE3I>1dSkg76H0`HL;#1)u)%>18 zPg96jYQFYLyCR>Yt?l%4X{xfUMTJt#X z_o2?d=iYOwYS-Qd{jC=-==G)46^Y_-UkSqiK!BN@yRG^4eX$^QEb)4q-6vG}g9*gw z7MW3WbGl$M9cdwzh!g<#JThC#4CIVSjXm&$=29+3u5`FJ$Poh+&5%a*KlDd;Z@kq1 zL;u;cQc|S&zoX~MwkEFrvyP0TXSqiirDh65zI{BqHjX_JPzwnRO{YS8Rt*SU*EipvPYbaoJFNvwLXI{Jw{ z4uubWhnG?c`|aGp0`PTbP*+zcB!)P+zt;C)9i-H*1ug&Uf{A>XeMyPf^55fitKRQB z-wxEIlrV`o^;oOEcR2XpVceJVIc;K-Wxw`e`KqYD3jdqib?hcsx zxm=3cb)x;cHIcfsvm$|UauQcO}O8h2dUe7eHNRho8{$yP=ajQCX&~Pdae3BcB5&+`cq+aKj{x) z(?@=O?=x=1Zi=XC@iM=FC2UI`vG5*9DV-{pClkwVMZk0!KCJXok+R!xEbU zCw?JdR;*m&;XyB{d0~yiOE@^(nV1j9Si+_(-l3xLSOKA%RiIl#yo+S}!#3A4X?WADT-5>7{%X(RAOE z0ZI3Fy|tni%vU|_eO|d(Z`q|c2Q4Hbh+xr2rHeYLKk6$e=`$vd9&TL*>6OZhECj=f z;$oix5NXim>$fWu8I5oPzpkOP8dZ4B4O9N^d^a(rMABs8#zj}cxgH5v{3Q|4aK15Y ztV2jfB^ngDy2SyT!Dth_BL6RrV1g!&YD}rDPYputudp1iw}e_~i87{$uwOhk8610t z1T##{$2iTV>0~N<4_mW7@oz{}g&IUkUuuOf<0=+A?~*)2m9Ip-UY$~Bc^1|ymHoDa z`~En7PLM>UNDFj7!gIS9zH33cKtAbfg1=TPcR$xgJE-uwjk>=-`LJU4he6|4S6Gxg z@a1psIr;b3-H~s-A**jwYsjSr4Wmd2tZh3XKhEv`YoRsX2BPtMVGyyOtWOHxEMkFc zpy|qdr5+?RdtuQ5MQJLF&F88kyTiV(Ij%2&lP)xxH119Hr>Emq_21rqsiVQv2zd+x zc3RhNMlzRv-5)G>I8=C5n&lkz<~uI-`=L!Z0mogwPpmN3TH>0W?- z7DSeXFid2pc1v}XMr|skK8vd=ZG($c`Q+_JGNNB=79K{cII|H<@!OxlpGy5>g1-kA zr#o}}2)u*c*8?VOnwz1ze1v9E;i`ue3Ss9p#7Zs%wfcCHMmXo$qar$Rzzo(G^qmh; z(i2~^wAh?8gbtN0`^#We6c_=QL0>8kS&C$&gFS3PGyOZ4>C}YsyOvJ;m@;CLkg49f zFZ{hfO0&fE>w->xh5r;qwfd3fgAUleL@N3(CVPCS^``6>6i|J5%~dF;?o5@y_<#T0 z5wjd@Jk$)ui2~24^!`KJzD#@Pyzs_?*|S7zCw{9GbrRR$ug9pvdW&FKFxk;Z+@}&k z$&mP>Vn2Z2p(fBAqAJ=a?ZiS5PZ!}t!S8cO+Ns;GFIP*l1VG=v+#aT=)?dV8`la{)$gjuq; zhph~7uz>pS&rItg4uxPesk08f>C>z6FHXCpVInX2F?;ZA153JTm}}$;!jK?hN`a#( zocU$%9xCiPDtjCVOKSMy8v8?0vpC0BLTjMsI*};M-GLeYr~L#y+twkNWXIMsEqyTs zMe}cRHV6I7K0l9j`zSJphZXHg=rXesC8aC(rpdQ07_d@)(%yR1BHNnn#O6Bzh&}}3 z2ua|}@TpR(LPlg@SqN>QB~3*D!So(4iar@=2qEHzcF+W~;gDS})Nrg^iwwMwaPD4R z0X8m-?F&_>nrPOh;Vd1bV87IO1!U0Rk2(j7{`@i5J7)Ue`O9BT^r^?NVPI+(n+Trm z71f-Fo_g0KZdZ?mO6Aj@xYN&R6cj?@c-kb=9vF1K0y3pf{uQ73s7myt2fBWXE2&7} zM5~9*<~ynSp$#m{!xX6v1-m>dh$c(nJ9SgM4qv8-{rEXH4It!XObVs>NMCE<_r-0J?GLzV)kA-4qS8znJx|2`OeG_vL!mB99L*!S?3;KR zwX$_Pc)P$nUXgrpb0_S>!O3|covPTg-m%Y=wXk$j%}}JdEmHrBwv*0&E1a6NyRWl( z)O603^O78TuH&x;N}?yhEF$zjM%W$3Bha29rhLiP2$PED;^w(OP4^EyN-?luD2y8P z4WBy7;B%&UQPHJM7hjPr%9BgKfJ6$Lf%3)z#YL4l%KCBv2JCu2+$`X$ z>4XIRvXobI=S79wrQumRn>*Mp_j68JD?YW8pU@NVs;JNtk=)L-*6M)^PouxZL(>f= z&vIJeD0k^ap1~H7oXiww=9`c!M(~ji&P(`~_4N^QzBfEYBPAt4ldHCFCes?w7VrBf z5P(JnQFy^BPxkaXy_PAbQxhYOA#-=>mhCmkxkOm-5MsJE zU0*LpN~jR5459WvMrW0ak?J=LL+5q}P2;rkpBvUzSK(l6r|qK}u~d@rK0S8&O|RuL!|j_vgkqS_9XI%Jx>DoTdh{@3>eJ#w+iYRj6L?kdlmYj3=VH}`8hGxqwsh*#= zGDli{W@NxvJ>m6><6+VNy@ukWQX*tE?Jb4<&>R6`MT$6ix=hn_YL8+CwxVWHG7;CH zgl5~sv-ne$f3mQ3BdQ7@fm*9Irt8iq9W_%xemY?pq^>x3Ow~i_mrO!8eYk8lMf%u| zs3cr@Sow(jL2T+jR2dus$y#b*$z@?#Jw}B`keX^iA7UpkjjBj{`GzgabFmlc*H}Cc z<<92@3)a&W@76o6v-A_M;qz}VdlJ;roqZn$0DGF~YxX!I9cB=GZwKxX8T_=rGQ8h}#W*zG3h*NlSs)uvV=k~dChWkic0Sk){% z?wmsdd>thg*PK#eJ5o}tK0Ct&f@NAF{HkcjQ3PMdi6VAcY>sTS8;%M9ptD+W%uXW= z+@urcIz9IGKg4Y+$qdR^d`Vnoz8sYkuuei=Op>YnoN}2PK8=|p_Njm^Wgd~2GOF1F zTUB;*I-I*MKPk0_%5{2XQ{oaBE&@h`AlkfNQ7k=i-uxQJK&>#WC=XSa4tNrO?= zHlOKB0rJFr)+O#ctD6WgHG$!88Fs-Lo1jAP&r1UY6*k5)F(Rt(()Qw+JSm{FH^0oc z%J6k($45gZzLsX4;SF>%eJo}UV^64~KgDMF>>CzY24Rp5ktU|!&SZyE^wI^i>B#r3 z7`>cDU@zer!xU9_2b)$RRf%Ry1dYH5D~HLR{G};T_4y{jkE{V}?u#E5u{2VcvHvX} zrn1l#LyiYA7LqPc*L@=2qYNx161B9hGtUGn*BynnvCi}8Wj^i&!xU!o!(NCJ zCE-v8H>gS{##ge-?D}zXBz>z-5WP+m2hc!E4J8M9Nu=?_bXUe}7q#7wOOIg|U_>&~ zQ?b97jEuEfgi9`lKQWb9sGNm7@V8wm%g{GJBo_CTjh{t+?OTciTk{W2yO|A3{N7(l 
z%zXiXUZg%pvm+n_kAC-!fv$iGpoB%Bhs_Wb?ATGl;aVI9nX|XDilb}{YA{`bwVJ)j zQrI|VLy#qDJj6)$!%NtAmv*mGI&PF+0OR7(E*>G7pH%voG4yOQSmrEJ#H&%)l*W#H8oWwGG zJWtq{Rdp_x9Vtv@%mb%u!t@ri@QR^B;I%c13?aQNt2u@I3S90^Ect^>v~f8gwNS`Y z#q#2C2m=d|`mgk$t-^*;4celv3mD+C>8IYs*cWIi>jDDaOYrD{wvviQPIC&2(r5Fs ziX+}iw`709^TkDFVXp?)T{9XHG&SV#w3HLh(AVYvpctm)U9cp_c#9RML5F>2SN%|i znqahFTP@wIj&A{8K2(Uu-jnfu8#*f%`=avqfimA5BDFF^MG5u`>M%!tvg~o_{fTXIL8s^Ba5R%{pjK|J( zdQ+sP4I66fT_r_4`y7F#(xc6Ck-EeE%##iDE!^4;HeZgGKO0Kao;Gh7i5(=pu!O^o z9|8glMI^1iEK`|QXO$sFeZ6tZ=RYIwsgc{xTV6&Z@vk(A!BVs_)}tkG`Wug~$)&-s z83QMyB}1hoJBwHv4%_=I+r{gQtz*>wwseuu$l-C}4^TQ(2T_qJ0hESL5V{Iyj$oA) zq7Q3m;&EHlO9iROz?#2~TVZzZZzNFeu2WR_&;9ep>1zt7nb!U`OwrD_5LWt$Pzx#i zyOF1==u~$NL2I$JT5HaAG`;xw3o;uPCF2M)863811I|za(gn;yfvh4DDLlSZ(XW`H+l9ReK)kA}i5m_%gvXdhO z7@&tQ(}F)tdJf5&_ca}!?aw2j;?fjyDZlvXhhAV*m;htx>oJH}$mj$|aIp2A|` zuT$ywo8n^YG@V!b4ODxxKwLJ6R7V3H!xGp1Y(7%sGTq4oH?s->YqTooRT0{xgqegn zM8&PgkcZHcL*Dr3kVb>L8opWPZs(PCBAK1>Ua_kLBPxuuN;iGCh99(jx);(}&on4a zsM_A=VXH`s&RM-RoR|Zzv>1&7yJz1uos_|{65b@x4Wm*S8?mv6wFOI4O(B_riL1$f zm+lNWSdMg}Ahu3gBDZ@e`!7yV`ARf8j7Q=xgJ|T30MEBuvwLVsVRfH&Ou;-~Ng_6E`xkyPa z_~y)s<>ZzwCDQ+`Rc)y)2lG_`HTHSR-;5{b!o1(V@Q=}Kjf+q81O5jzPs^G+O6*$f4HVqZ7%Sl#GDHKJL?t)Ux$x} zz!QElLwjo&G_rnZHHEf#+pSM+4J29!0}YI=U6fBjRJ%V?R>(9|oxM&#=&JeOjhhy< zV@+B-;Y*3tp-3lLw=*KWCqA>wF2o_XOf!8W{Q?p+g+Ez|0|w6n?xe&@aRy{34&Uf{ zhq?d-ja_(Os$~S>_WyjHm5P9ng{^&*lx-FuChc&`8a>3D<%D(a%Xf-r9{eU*5sv4V z_BD;tP@p2{WY;y=Y)p}`y>wVc1yRlv$4!riV12Ji*7$PBcqQDlX#0SP2boAbr`>O39{xNq_UkRiLlMKjbb(KFhHYVkJ z`>)0MfXw?}PCyTwW{jQi2~}h9UjP z;fuA~2xV~DQyzK*SNU$nUJOTMq&j(j_?tDS@qhaui+-2e6vwMp-T_s}VnXxXX~v7+ z=Xtxg2#whuk{8qbcs7#R?uE@HyZiXXqW-dQxAC{I7OOUwIrOaTQlud)I41(Of z8Q%J$K~3*+wLLbZ%vcJP5tQ4j%g$h`X!55lWw*!SyMf=f~5!`=96gjhE23*Joa6Tb}1gU&SWrvy`?@ za8J7YcGWLB$}kcIMI*8KM3}<^7dM?5e(TTisUVsE&lSD;^<0)MgOOMRv9> zlQmvYPqu4%5+HqaU=8FzQE*IT8V^I-3sWEl=ii;J`>z8RZKz&`TAA=Lfs992a{F1; zkFUycbsMaagF%;LJ4&W08^5R=VdaI)0fu~;mnVmPO|PgwDu)&Qs$q7Lbj%EmQ@V`FWDL`)eskT6P=b zhtO%@Sd2fa$#xe3@6J>=mpQo`N%-si1 z1e!tv@Opy(hv*D`=iDdW`+5Sh>#`7jxV>Nm3GEzWA@kO&KAmYtEdPY|%x^~Jt?KHo zoMANAU4$71!p+t1!%xl_$pJC4PaRSBd{Cu*-_)I0DLv%x#wp>$<8+~f@`$mO*p#VJ zana?G>qKV!5~+t_p)Hc`Z zetS2J`@m+to0x$;Gn`4)`?Es5L%RSXry3d6M%Z8!=sg7AaA_M4CTXU)UU$xiKO2*MjQquf*ZyxQI`(Tnb%VN4s$v(*G8kOUD`z6Lr27};u zGE>m2r6G~}LsD|OfwFR6)@ZWs@rtxlturcW#LTn+FEtDI_i}z;Oy5o>CMIg{&GF*BZvg9C!y1N?8!KNN0*(FlG+2L?%f*TzOBM*qD)e3Sh} zrH-u2h*|YVSeYhH(u2$RVppc7Vj&Q;lA9t87t{<_?zt4}uIZU+mhZO3CHyffn+~po z*c!^pt8J%bdtsIO#X`>eW5H+*hlf6L;o)fEvK-U6zR?XGmK3Jq6XiEj%jUUEt9k+;t<_hc6%-xx z>bMw!-~8HvpQ855M)~AMIdr20m`b^jaMkeaDG8ebI6^RsY|G0`9xS1a zdTk4X6F-~-=fLv5Cfs3naP9fEq(Xzx6#|R!fhJn#u0x7qYc{ zmci4iF^(T)9+mjT%q(L%>hW|i$ADKHtL3;uPs;Z(>+4tq) z(jL4++&p*4+_VuXw@IaV?;fHma$+UKFmoq?H zV%z8$1*P27L)C9L*Bs!1hi_7jN}Vnem#RN!27rVEH82vSWyiz|%~Gyzlp%@5Z%zG*O-d;9x!f*Vyr-f(iw%6JRsVl+`9U3R05{YGKkOdzU!q-t9no8!I0m3p^hb(13D z41|!Io3VhGfbj6)2m(D3+ePVzd3%p*DUnCslaZRpKR$H?tgTy>E9VD)`MmCSomqi3 z`wavGQJF4-r{u&!P9&jjr|VA<2Cv1&Cw9+m(0K=fFo>yT!w8o7ld+cK;Sct3=Y~aY%8;b$x|Kqma_-y;Sj! z@>FxAk4G2l53Dj~T642@I90w(uwSI69NEW0j2M9g9E8=*>mcN)KCOrmCImyFd;^I24yB7B+8ruvLyVse1_UIm+QCR2p2VL~1~ z>6yvx5`*}tuG&r-$y?D&U1$74+O&PXSj8)6*5<&G$wLmgBU8jMQy9UmIGzs)TiqU_ zEKZi!7O8V3f?48_DG@mJo-yOjQfY;rP_|JPm~#gB6k2f=MGLX`(8aj0DeSPgU@EdO zyes#h858_)1WA+_M1sVtHbPM--AO;hADcTh8zX)kCbgd)@AjgMi7{Rp8mGKUssbA} z{wOE=#gM3mUtf?PK1zfIdw9^xE*JYnof>i~boqdbORMGO7tkB^_KovK!(@O&6^DfbvAm(0S=%bN`Z=Njwvf4~y@ z1%VrgCg!!;&R=a`?U(yRKtN#fwcGD>Xpqdu_Gs41ea%PHO8fAH-x+*)Eu~$nF8u#! 
z0ettloueL2`l7>~ca#$&!ewZu~Yq5$JNp5ra=G?Y@-1`a*xnaD-AK znt}Y|jY%>dmd$k#Vu|x4<(CR#zAFo#&1V3N#eecpBl`d?nPKm_@bK|Y{gO`riSkjC z?IvuWBhN^N8&O4S8Ey)F)>9ujH-3qloBMG936+KvqWv}0icw2xy101%d+#ZVaw?n=vg!D^ zwtOOvr@E}Z_g?QQNm;@^%O^bnJs}|xe4f>nl{6naJRF}V#_ii1%#RY50+BE)OUiuQ zcc46))~lOC-0YtZ5&ivpNSG(OEk{8Juf!@WWR2U7hasdX!@C4yn#wnuSW_xL-Kv*+V+3qL<)NX!5|zt{gsv~3{_0eJ3zLHc8?Zm;-z#_)xJ{*dy@-2i@;S3uF zvxkiu3c<$)arEycr<65Xz~8u;4%wEpIM&Rlew`JOA!%SCRkTRaj=pVXW9gJh6pdjR z$?F!1Z$M^}`npX$T@M1_#%3VF%IDozg=)-908RuT-GIjX2(A$yt{#|1(b z|IsEp#5v&iv>B3#XmYWu3&o};rU3oPV&cCdBADzs5f)Tj%$w^0;~cCQiw@;0!-scD z(o1n_i_=`@5~Ldey_E4rk(ZB3Kb;gDN`a+xF(Du+%_w)6_!a`EP#%uu}R0GC~2(XNm z%ud08=XGVv^I6+`g~9s>qfK>ov+YvtXorAOgkjGSbh3W+E};Zc=Y#`-@_P}n$h?ns zhm-dE0-@@Qiesab!RV{i96~~*nm^Q36lp2j-5F{G6Pi($cv2=NC>!c277VGHQQx)Pa#bHd)=Rit%O-imho4#hFw+e8`|0Z?$-K;eu^ON%!3gzli$IXFTNXFVC0^xGrA3u@2X0bNzgZt6!yZW&oY_@|#(3-2$zP*kh zt+`!G?T=|sIb8P4GoM2xYThrOh|37{+DzxmoeZ6vQ9n%>xSb1SQ+a2vcc7u7@-@{Q zPDksnC}s)^yU)KYKea}sFH|)jM!&plNB`lhUVz#IWO*MQRQ-9I&PZKY_c|{MK_^pW z*Gubg=)S#Qy}iVI6AmVIw-2tP02JrBT`4_ zc*}ixHZ^L1MD0fr>HL}A!+_xXAu9L1G-(ftZh%mHM{Hw5?i_Ty_Eu|ViGC&(lE)%u zP9J#`ZOx%Ir^<#U!YM7b1%kw}h``9g6li6R?Xx$#1JEOLnCuw9zd{G`6Tw!vVTXXu zT`nXk$wWdW+y=V8fsVoQ7-QxkDO7qz5#zmGu1lM(YwV^HGlgkmP5 z3Ysc)Nn`Q#K?zw7>C}MWhhb!+*cKe}b;?wUM#cs49$Xa6!?=^+`U?CC#Y*G3Kz)%6 z@UMs}dV(fDTPv9p^l7NU&ZE=HeJ<``OfMq-*1|pc@lQpn&uK$xvM0Q{ug# zffu;Hg3;lw&;P@-qx10lIr)2y^?c-fQj+*zOeY6(Q>tl|f`$fsL^nVTsR)nf^sexa zn3suv>Csk*=o`8?G9F%XKe>9o zc92$K^!#l&Qov9RA$FtLauK$s+kHGFf;Z1|nKM-c$#nx?#=c+#Ybr>#zQEAIlG5u7 zR|sSP-lq=?0{CFY2}iaIWefBfG}WJQdEOAU-E$h7i(S0?`v78+bS&0)r|Bv>^Y)8V zah6G+l22DVPJ zJf2IUGfmpOChsx@n^?ti`*7;vr^cq`}C z8Lw2y4&_8UD;d^f1qUi|(ESpr!B~aOp^TgFJch}+OdN!=2pVqXhqMxYTv3V%2*I;x zdBgyBHO1m6r;v~rNE|p|P{l?OL>?s!!>DwrIf^!5D#)>MsBswat?EL`$G$#wuPR_g zhU+|1W->G>M4SBr&QZ$I`Q>h7G7xgh#E_dPs4IdJ^88=mcd}(;tU376W#h1YLFq3UL+2ulTaLC{C%=jItstFc%S-ruFovJ%uo#>@1jQnErwLcttJPyNxtV1kMB8N5e(04$7R4@;xs2$wxNBmsPD6u#%=Ro z`@>>g@26o?!#T_AtO1!Rbov4c6lzske2{8>sBQ9wxQ>u?(^DcR4&t5$dIqw}EUI#F zyk&r%`EdC3BoEAsF)h&#*|utW7+!sJbU&zl5VkFw)@XNC>^uz$n^=P47)xX-{({(8 zB}0C9KYiVcG{J}GZ~ahBa2*bDd>N6~{9^HVQW+Mjr8D!GU!f8l%*6T3?cpbs5b40S z!>#;W4#riD78aqbqpb$ylH%q-4YKK-K`q=wHT*rxFCyI!Mbmj{@|c1Dcd!D$Z|9=L z=h7Cu(aMa7ia~VN2y9|)amf%hJ$1Lu&F4wr{l7Nijq(!yvXA;JK9^t8$86iv-P&`u zH|FI$cIAG!{?OmLDlM9kIu10F|dH55f0w-;MotK3)7aC~Fr z+vn65)H^S>AtJAj_p+nUA}^Z+(0fC&TA+yg_jV817j@A~_u1}F4{VGx{-uPZ=+hbW zEeo|;nhcNK32tYCn5vYUg>b5AB6N-<`wNm*sU2zaK4hC{1U?oYEe%GBkY6d;-(dI-Sq;7 z$G0$YP7BdN0ZY>Oy!cFe3ZF6f@#t9rC?zCb#&aM;^f+S^9YO{?jx2F$sz{5#mv9XY zba*_W-JHZLuhE;R5jQG=1ouZR4DOEz@sCm*s( zj7!S88{T-D-FP40c->oggTu<`W*s3QX2IFYbl#lRS(12o-25mV009 zU>P1bJ~y5+kB=|Px8IxG8*@7L4Sf`O9#*F5t zMgjU)a_JE|CCzl*yGr;R{AMbO>e?4_SP*P7z@8BqAqSNVvY5)uBsvLdd5?!V8%C08 z!0S)>b{V_dV(~fg@XpG_YgE1*+#dM`_2nBzHdj=`umm?8{Y@PC%7+j9ZS-OwBOcX$ z1q3-yh@=>%#E6`C(l#bvzg*Tw}48Ilb{r&x9JQHG;?idRt zk?Wovue&pAo4d~fE`3o88|;ofT${&7rI(hUffZiRIoZ9v4+BvZCks{YNvG$eB4GF; zuG%yV;`PN2FTo+<+b~uTX#BYO)WnMQTeo&L9vGw=#%%gIK}D-AiLGL_*VRIu3R_jX6i)x zRLvaS+_9+?*A;#95Ad-VU`I~<)_EBxmN$sY;Ps(p;H znrxA8VY}2`aO-s3?6|YObb%Rc*Ks;uUS1A86pT$A7psk$QHdqNc5es5_t2w!@N}gK ztH-W$=dNHhdCQZV(_;MP>fHA+^Y)yFhkGJbu&uPTw7r}pX8GkP+vj$-1hQDw{!Dql z*{H@R=rmkg;4?klXuAqE{0T;Z(6H##V=t&}8J~*0LEeLpTn%N7D~gT;Ac+HzW^Ag}#e<40P24HuwpVMD{h-u$QaKRuQ+9R{&p7gOL>rrP>? 
z60Zc0Ck!&6ko7~#xb8+TWyPKw)!A8mT%3PIxnm%~GVIB84ZO)g9Zf6N9E;ntSL82El5xDm>Qf@$X#UZ$p8z+|w1%SsA;( zH{MrO&*_m`fkGl$?+T1mEu=*5hc{g70)mba#zHNs`Pn~sOkJ_Uf+UtTu7;T49-Kc$ z;T;y2qW7{t>Nl!bWN*2~L);c60L8UlOOqs2Fjic1OV7gEuI zV~K8t{HdaiIKR_^03eNanISUB$W7m#sS^(#ZJHB$2|fEXVwiO;*@6cSKr?hplw&{* zj?ji{iuE9Pv}ViL{Oj+J$wft%8cU85Kg|_XYsN#$&j4}N(krrLsS+?sc6N3X&=y!) zry=;few|;N-{5t_SC&2y5$or_tlF93iCUkl_++P*r3FA z4_O$KD8#3|b_aeQy)krTuI(%6+hGv2Q<*H6VQ5^6ofD$M`DCR#Oj(D#IVr)VnjZnE z6j5gDP03DT=U$RwBia_Xf>9zoyjf&vuO#kh%Wn*~6q6O$`;;Fk+o@!=KzQmV^H%97 z#mGx*58X_(N|im0JN{zloqyc+pvC34RWzls@-!)5`tpj3frhcr7?M6>Z8wcG(x$a0 z+u`uHFj@mz>%F}_)@HjUsq;tpzdIl2sHr_7<*SnBlStbd%XQxog;NLL{Q3^{Q zPwpBlf`aI2+tC`^8$)4@6ExRadA~6(XRPCl^G}UqQ!I5SOt=KtuP%=U`gOVUgZf6i z=Sn2MxCq9J#4aj}4|A)HCS@SDELWzG;RWbuFSC9}siPU#6PJ)UUWgUviiN5QW5&-g zS)aBAe3ayPs4Lc^k_3gFWKwXYn-7A*h6+z8{XzJ{T{+Tw9B4eW44u{k1)xKr1(!oj zFc&2KqvjyK?=-2)C7no^6#ib*&-hQ2xGg)jLzdHS)=Tjaz~Dvr;Wt2WV!`WnBAbmd zoyGZjb2MGILkmG<`3aM8FZp%Nq~%B#^?E@;*g5_YNH1W}*v-?(Aij}iNv;;R{*$q! zw%h!du+|xrqv~>>bL7T%Il|%1$y}#6LQwF}qKw)S%QjavsZXB*UWUy>ZiSf zuMw|Z6Qg6P1!aF5kLQG$tbC>|3Jj|M$x#x6p*~ z7OZgB67uBj7q!uS`dyUEfP2HrkQd6$)#&^t^UucoOBX|h*+LwjlbPFHkf;0LS?D?A zysSA=3_XzTLsFzZ1FE$$m^jWJ6fC8k!PSW*iB7~z=CdKr z(8&Nn7BECOzeQ^g15!bzNwssoOA6tNN)J_r;1;EGi)~D)5)(J}#3>wQze?k4e{7W3 z1X#m;mBiWA5G{k0YI)(6Ovzfh28evp<>Q05GKLGCmET=YMW-Po>q!Uf$4MeHeX?j9 zTJQLCHE2Mz;`q0|+=Kc@@|BiNuYh=;-=7pnu7IV2CSe_sTuv&cQa;;6u1dT|=>${~ z?Q5{<3<77_5R+NiH#G$WMbJ!MEJAghxXQwiGy`MjXqfnRf?q5{Ero4rrRXfYUnJE~ zMUVz9sOFANH79ZE2;P^e(ZS(AbTCU&hdMZka|0E5&nieA5YsVxcdi_>jG0o&TMaj4yM!MJW5%1spL_V0ZS)ATPJ^>&~BT<;e3n4 zG0Mk#a&r%7_L;O3_?)MnkEZWCp4~g!E-S)tOaHtSylZQ%zPXQwpt*{??_n9}+|E9^ zukW5pJ8U}&@G`aSY<7QO?YNm`C4c>E?^EjX^&POTbJ|^h;{sK{n1t^<-EX@;S$lg> zi;4U!1YXKbUnxwvozmielpTGV?6iAqBVMTZ{)+x}rf>pJ`5DD45DDW{#Bl0tP51tE zgVXNDImC>!D5!=sSIwxIh#$EaC%}l^F-qDgNp{R~1kdVOt}XVW z3Qp4904vlITk4MiKxr+)uBH_Sd zCs*UcuHFRd1s{lF)B(lZFI!ELzrTNYbLHH&lNrxT)kYlEutc3eA(i4sCmR6{7}PU~ z%RAzlc-0K__ozcfg9%(x!zI_E!#3AIm!~2uNY{@glZ>B1ca#WfEeS2j&l#9f69s&u zn3L7L>gB=3_e&!q3^mkpPJ|ee$)Ls*$~%JmMLyAL*3jr=X7q%}p~I#G8OB|`*y_z% z`(=!3!;%%rni8>eqhC0W&=ocFQS-r5Ue=ua2Uq?7;5u&oKX5&$tUq;M4!L=!O`+F? zGg<$kwS1}B30?QRX9oT#tj{NJyelnJ>TJ9-%jjs{Y%O>lMcRVzk_JGzDm?b!`$G$I zBc?1JPNzOXp!a^?@DAF+2HTG(k_>-oV6*FOM&R>z=zQ3C`a*W+;e_JVdPgWGDfm9~ zo?O^zGpvZ5p7s2G(fukE z?0opw>3eVtbzov;K3!N`Y&N?N8O(a}?7Z@?Gyk5Rn7F2)uPBv^SLf0YENFCOrMeYe zh%^8o1h59kcz)D&A^GdSFHkspN!(4>GcR7IM%9qc!j5N-6VS3wmoh&&P+cZDgsjy_ zxd(sIR+mPlTt}OaZUm$!F{aE&ie*a2!1_??k!wDrj;|UV>?lC2*CWugwYwR%Lcl_5 zm13<=zx7{PN=;fKgRAEafYDtoTDEo2RH&6%c?JnfX7#0zwN`Nj5sxC6pFMKFb3VLN6tf)uTPSX8KBKW zOopg_DLx)s!jO-Ow&TLn9ON3*JKT-i#HuGS%{JGn+7riucvln~&Sw?$D)u?RC|IL+ z(q5rPK!1%km_Y*n|03_LqT-C2ZP7p=XmAZKK>|$(-e_=l?}lJO8+Qnj;O;cmNO0H2 z2?Td{3-0a?hkx&}&&$2zJfA!6J6 zJ7s`P9U4O}u9b`w%_8YCgM9Je5jKCi5g#IF_5nYp(w&m@Iw0vZ*FGfD9wb(6%_E3e zxE`8Hqpm4KNYt}uZI5<<8XWJ>Xlu~moNH9b?|~t(me=dAy;i1> zjR;VMsR#f2YwEk(W|6llW4@-m*=HV*R@b8bZNYat0^GVC~h4vi5F*Dvu5;_Q2Ek|=@^y1j_%21 zfJIhqQX@rONrMJo)|%m8=0RAhpjw$_jSuH{&mr5?Q+2>B>y{HPY(?Zr&Vyugaf!UOpV)^bJX`v2RY3|W@ zyKa(-l-lO~qxVnjVX4cQ{nf^$_iOY|3EAw$mU*THNb*YW0!U=(2b%O~zc2B93C-4H zno!X!EZS@@5~KvF4Z|9Eh)fO6{s@ggZ=8e-oifN@#>WsIRk! 
zW}YrDlI4@gj%UPQ0^Q|KB9^XdeO;#gcIxk8j4df7+HIjXOuu}zr#2m2ppT_1*g1vs z!>_GxJ;hW?6Uxm0Rm9~gSGoE9Ye=j_cIn{9|0gtgc{bq$woj+RA(@j@wQ^E&1t7uL z?K(oLbQR!}n75+3le*c;P+Xi(;s#?BnE*|h@`sE-@6PJQe1+m-67YOtNxISXA*Ahk z+Riy(Rblrv`aU;CLuc1|#wLjL<6NQ%4a(UNE>W$I;ahmVusrc{ zfV#k*B`}^rIoHw`d#sHH5m%j{FEQn``cgg$WJrdOnFGiufnR z@-lLpl(@;_hvq&>qRhYC<5v=~+7ja*hu&&p#w+aw57WGBZW8w?&bHpu*rW8pq6sXP z#}**`J)c|p^~sFig zO6efYB>SEI-YJ+nCeuU3&uD`AsS?;_sGfGILU>*|93;UJ^vdcf|Msiuo)Kf-4<9vX z|B37u@oHz@qGZ939mxPFu3BpI&{JTrjb>E)X)}9*F%;Sur*+1zaAzuqZc{SSH$g)y zL=w;t!2I<}$xo9>bNDzRq(p@LM0;!9;ZEPgbY0&CO;*tH-q0!Jsoe39_rS4_OTkB* z|55Axra5JTVY0KFh&9%ixhV!AjX*)`8k*P$n$FFgi8m@5o^!H$8JP*D=Ree#;CqR5 z(cUH(Qy+86|Ga;2cEn&dR%VQAd^?8!ph&UQ(^i%QPW?az2BlUN1SvMvA)L7dG!L=6 z&uJl(V76h1Gn-YImgi~;VygS3#ECVBk8DG?^LxU*Cng{A%{^z;LS(St(H!d&Bh*F7 zRstwj1u8eX78qSI$rGC|+t0e69F9Dal6Lxz{A##qqXi@>qqIYd8(vY@-kTD1x*sYg zHdy!rdCgRs5WkYBTf_*f&ZfFdhvJ4b=m>0?1mujjQENyFd(Nx;q_mm%l<-kU&_yE1 z5S)%aiLa=wKbw`}Q&X@!YpJy-%t%tI=J8EUZG9mAa!g0kHJ~>#GKRoVjdqUSDB1-| zS9^u?=C#M>ymH}q$|d%-H5SA?WwI5_38f?s9i-A02{-EOUc&dOB~kIirrHnrWaaWn zL77Y6Mc?1-emKL`pg=e80cKb!x8)%#8_J?QA#OrAk4wqRBc}FKHqDM+-sxtE$4-9jpAvUoE3Ac4@&36t;dMLw=E^8t)(Wk!nhPPoY2+sA|NEw9 z_VD+oh)2Wsh5J+SQE9ej8^DD%x^3c7eJ%DRgZ4sMbdx+C57-DLKWuUO1$0CQ?A^7> zI{1&v(~P(Pe!Nz+JBCl?D)Kp@Inqudn(zbf+drs@Tq*6vzS$2!`AXxIz!%snI^j5F zrK5=;Yuw?%6euqPdPQoKp3>jR$i9LwR!J`+*pdW&5bS z^KT}G)}47!O=wN^`};#>Neac_ecUY7jG9}yVl0f>-*fk9+BeP5BI-Xdgc1_1U#ja0 z5Q-i2vX{-f6-h5h^o0&#N9=9xqCjMxrquJK5ew-k-vn-kFtbPd=W3En!;2Avx&kg# z(BtL*Kc$l%KMfJbk7GhO(nTnM@LDi zR7yMrAEHtB5~DaBBN`=<#9DP4xG*LaowMPD!zUZL?Tw^HzF48Z$x2b*xP2(3UxQK; z3%rfVZ&V;5L@6xe__g`B>BU^XD*hVR&SujR~(qM7q(y65C|Fe4|qr<#Z6 zi!9X%(k{9(v>>Pd86@ecVZbC&sY7w;FNJ(bqgprE=N2y;X+|A4pR+3TgFs|tu-IgYGyw83|hM*1U-Rmw!~V+5iS zmH7LxGm8Rh;QLYY&Xzo}o3hzmFQfm`L?Eu3+I|BHTe;+kDKZw6F=K){m^_B;h74p8 z`aL{7Tg7$D`r{PK$HyzhIs4d!AyK=}c>>uB%hRq2B}|LcJw0q>?9_h|6RSku1$3#z z;UFS0qiwNnu@=q#5OCTW_;WLVv-IG7lK1_ecG%||-5+JOv%njl&qW6}=S9;@jf+=I zfIgG@d>}DIIDVb3_&!CI-2(A`vqbiT6=u*HXK8BxvXS91x!mT1zb!R;F0rycHkM+f zM_E}3!-AUNm2A7^C(!Uy$S#cFnz?2aYvEX2TJY0dXn2`r2$>DTF{FcQhfs4AO(UAl zgH-45mv8L<@Z~rea9#&8>S#!c(?xhj)6sGa;CyGuCBy(pB9#&3b`kP)V&~!z~5i?hi_CV)wI0dWp)?<4a5d9C}1OG;TqSUvc^?T)19o1RTgB? zU=*&;G(hr5@lnx(7cGjjf8$F1>BYq*HAUsXp9n)PqxI85S6fFnsEi+Y`ot8Y7DE-k->w8`3RQ8)}FL8n?+YiaqHt!t;UI#)iFP4xHNjJG z?kF%KtXnW05|k>EH;mQuCpROVYjS3@DNRYa{Z+eoIM#nDmrqJQ_X=C`k^0J_?`psK znVobKv#qQ+5p`D?0s;ackCe092K~PkUAC~G;~20H6x>HD?u~ttDg*xxOmjM1u$kSz z-bvXIA`s^Y%KsRG7^jL5p&V>X3xxS~7!wq0DY=l9V!Vh6xH?x6!+XqjH~o3qmt1W? 
zQ#|rFi@-m8LN<*7umk(Vh~gCA)+&>rywmzgCfJUn9N)Xl@hy|G|FJGIdy{Pz}G5z54xpVwSidu5hn1VPQRiT?-}Re}9{D(*JTH$nJ+05mzSXV0)&HeM){Y~6tk)sE#qWm&%TsVIAhbQ^ zif=6JVPL*>RdEvhSl@2Ul>$&2l48I}?J|zRf676N-RM=K6mWd>D}$U-wd;n>pi~Af zmd;IormI@atlTmAcT(3uF8FJEL*@@&L$=;RV#d}7Y_rs__8v=)e!pd5(cllkj5+Fm z6&$uKY`m!VP))6tbE*&WiZhz9>@tYBucFUA4Gycn%1m7vt#t@20Vo=E_VuL6E;Y(b zgZ%xt$l!x@fsMG%inX|O*?r6z=%L!)DkuRJ_AIWOFhZR>ec|6=lvW&*Y#q0j+$@uC+wkB`I#WB{3 z1EEMkp7@MZbDMo4b-BpD#w!2tP_dP@+#Wd;*Tz%PZ0U{>pV3yUkjEX9U^kT94u5jr zw`1tmOB3D@5W9}HI4W_$dYWI?9X_51F7qTtmE2&+zKmZ56!o5kaM>Hle+ZPaT??Fs zLx;Ecl68*GadrqkRz`v9m?FsQCOXb+_L2w27smHZ+i->&Dxe;1HiJfpQSLmv(^hWw zbTyk5!w0shOC761RxWLzPR6Y?-@yVZj!rxu8>6+od5 z1POrt91{DH-b{0x)Y||R{^E{q=Yx*0Y`;SvYdw|~WyXZkKLdFssel?VjlpuobRzg zjFj&MVjFM?t!8av0G&mSzuvon;%ax-2)u~tMko$Ob)Ur8CyyrO!hOJ+Wm6aYzov5{ z?kgY#0JgSe1m-pi=C=6!xT-6$f$mu*BX2z}i7x|0-zV4qP=2u(#Qd&=Dh>!A(cocG zqCfUZAcatG=JbSnOc=*}CA60Pb}IBl+m8MhD`htBK}|e@&NdVte(yY_?rX)n#qwKI zypIS%WL;$_CT?B6Q9ymGD!8T_H` z{YN%K!>gp7D@ z!UGZ9IOnPx-GY{1WU|of=s(|P_{16_8wL+S`n`*2dT{F{79_p-w~It-)6NseifO)w zb7R!$u>YpvnLSg-c`VGc(dRRgb0t+9Y5B{qd4i$lb^0%@6MurQ_9Ou!S6;r2 z9^KGPp4;!CJ5$8530M}@;ZHdoB>dE(BEAu(5xJP6dzZuUHU+;!WiXpFaX1V=RDjB~ z=L70p;hAv;oP=Hl85~FKrQ9Y{iC8JVUAx&jg|iU?raj6Qrns2OTu{7|vzW%I?oLGp zvnG$>q#fKcqxwuu)Um4O&dro{ScCH^{ zEIk^Wk(;n&lhWZD&4?pgNarU}5%m}`je2FfrHje41XqVeHV~IO%1j?DhJ}VknAOWt zs$~&B-Rt1J9o8P(GXT)iUozb8`Aq3e2b9pC4w^7FV1SuUExYWFwU7tII2!pDT|3QU z@r&EPQBn^GB2yQtZ6>Jw7%?)K^4pwWgg90Zz37I!f;orwmlsyWkC$Pft2+LzE3d0m zwfwxSZ6GO&-Py0%dML)&dYUhMJ-xq!l$9;!Os?%7?l?@cI2|+Uzie=U|G(I9-y)24)w<37 z@wiax{v>8$N4hReBuXK=v)sX(D!KA%UdD!kZ%t`GbTfZ9?nVW-nY&|cya#LXjHwFfWtoY{g$nDbXkPg2?&=|NBmZJgccLX& zK}jQTUM+^fMfEUz$K;t`@&MY_|0CHU4KNzAm=HoNqns-?g3n@HFU7Y690Mv9;7E#w ziv`vTay9@3N&$L@3wVgL_;36*iAt;BVh%g9t)R^^2SLheRoQipQI8-^bl!m za_0P$uBZT#Yb`SShqyh^wuz&;F&+wkqEZ%Q z(~{oM9+b#|$$~j~0GyQ`y;bif4)R(x;Y|x7=OT^xjfS$Eu(TkSIvKG4|0g{xWXf#QUutp?a>zzZn44!)@&`m?xs;*Go}h$gqcat$VV_!mmJ`lvqfNlY_mTCUx)%h z>2Qq;$DYL7y`O$`=+B1$C8IAnjP6l%T7^!WpJvjn>3B_$}(M9PD(JC)1SeO+-BNkUrT=C*zgS&Ej}{j!a7}nfBY+d)@q2K z$d7P!ouS=+0G~MPe5yl0f4Xlp<68pIFq-g3)(O#Fq70=tkm^WB5&1Dx%5`M3lKn$c zR0V!JbJIYCi#%$}bQlA72Ve^R%;q<)epBAi#4LTLQ%~ME)rizfbZofH;R12!*_QPN z?Zn67j;;-5^;1ZLUDBJTQPlXagJTSPRg@ z7dU>ZFcvifd=>M^y9)}3 zPT`MSR_Od-{ftijW2PVkb9-skX)!*BsEjgQfY&v}7s2|;R2QtXc58YaqcmEQcz+e= zb~mIL3Ope~=Al|spr+0g6d{0LcG>0*@@b1ncvyGYtm3GPHPMXI{IxGVn|=ano90{x zbKr7|-R6p605q-_MdQK1n+2-sWAppx>ST04XHbinu;xhv0D}YR!(&4_jgtHO4pV7SSW8S2~66Ih%|Zr4s`( zl7XDcg0yv>96!k^mQb%qO_j`vtR!)eZFHyfCv*y(O4aLS=+>>78I4*4$#@Merrryz z=fin8&MVUf)MMh%7D-t`y_A$|rm2@f3;@A}iBEMQm-aGshzR$bDrA}@!+T8V%6a2B zk_j$p|Z@d%r{t5SudOiuhY zDlC3X7K7QxaXn(`Ch1L*A8^g(foKU zfNZkeTF>yTpi5^|19Mqqgy3QRIYI9C4%Z$z+RVvGYZgOq#MiBZTpRL!ql_o{=TsCi zxn6%tY5qR*TwO)8%(I%EyNePc?z{n%kQhEF=QvpK8y-|Np9X7f8%DbCZ$lY&Ncnje zQ{Pr`3EP+pA}AFY!T8XP4wX8!eBvh~NokGC7J?m#dU~6;6-4Fs5rrA-#^EL6k#_pa z+;>b?txCCa?HT-0GMedJRxSoex(#JtykKe%WA>G+sJ(|b_0|9%jA#Kzge8uxf-L;} z!Tye^_diPCk4Ns{;Br?q50pYMcs}!{Yuf3i_;arprGZmn=DSPQ5Uj3=TX^~z5?^|d z!#YzymB4%OouU%2Sx88&71e+)3Drp|Co*#p#%*$o}F2k9e?a6#l=*VkUR;p`n0fNVU$o7K?nSM&iY>XJ{yoXgh61*AW z5>qY5pK8hhb1h5GT7OJUv1URFY;7$aGk!bOlmT(%Hzr+s@K>Eb^n~W%^c-_pZ}9?` zlWAyYioKc4jl5bQBLum*Y^x^cM2gM)Q@1nI`ti}(+XoWG1cyZ%AsW-NIa)W8F@cL8 zI(wOGvEQR$MA;Ad_zupzZdU^x@n1midAgVBo>JBwA0Dywq3$BL>)ZX0ACtL_jBSKX zW~(EGkjVJ04Cpye19Fm~@%Xi4f<-py#c!?br+Qsn<~Q=aW-i^ZuHqc!cl!>r?VmUC zrB(YdSm6ADTx(+>u*4@2%oy2i-@blIraW)(xv{tyLS|#mcTGPG464_kbt+2daZU(B zlP>Z+A2sG%vvQ;pns@KeqTUY(ee!2uw-J~eiIE#Gb4Y8r`C<{9f1h%}Ei>S-sW0loP=4N*@YuOiIxyeUvf|?F9 zONyK7#$AIQIIq@}kcc5xh_;3Q!-pU51m3CDiwI0c(BX;LH+3 
zd44z9;bnx@zNBNO{m&Rqj}~r};o-A6uP>fXvtrCt&`vJcV2s3)RfHcN1NQ5 zvb2>s*W`sIqTLR4K8{Ome(M$cd07?S%1XZA``G@&O~n4SmFEcNLEd;eRD2*Q%b1yu zlNx|!A&Gxfhg!_hN(-|Hq-)4?oN(BG2S7=Qj?c~F@8TbhpJ3;|Vk|UzQ+I2oMCg{= z@RmCo$ikshZzjsOYqmfett0{&gZy*ZDwnUfUaHn2)oR9;ryruweGYY6>4>g2I)9x_ zO1PD@-Tv~6P?|3ke2!jgj~7OZ1?aOdBBE#NL06P_x?qlT!uhDx)rjj5H=|a7oyyc` zDuYk3njEm2RV-O}Ci((8Zi13$)og89r2DTc9Z1<1;rR(62QrcHq02%XWZ28VQk|5! zlMb5G>=}t)CZXLjWn+$f(rE*LNeaoBFWF0x;R^)6vC(mc0t`Q? zR%RA`BFn|s*G_T>SS$joEq2HL&7c~|*`N(*_R260_5$RF@T7<&b^^e(*Oep@;L%>2!avPOssYmXf>|Gy1z3L)#b! zbia*y=twUBTue+DHFO1>%Z~w0m@QS19YLQ8H0VneCmthy>-(`U)jW*VQ+b{L&7hOU z&WwnrOVm|%(J>e*^K5ArU^ z^=NBILGI>C9vF-?0Deu*s2OHYvX)Err%HcrmYTF4YPWKh1Ad*<9E_5 z#yMM3x!-x2wpc~xQi|hwk`27!lmr>3_s>ejj>ABV+l68*s-`UaSV9;^!br)>LL#ALv2n}Pjy!e-zJtnots4_I z605qXRvjeExWtqfL|J&bt`1k2H=RSe4U_${NcFa%RxwaKwJ|AzGOo8u zKjUqvH#-mc{Z-V~)j zJR!!9r_dr#)*dmTi#CxUtBDXkw;njAB~tiF+esaZsutE#QKc?*hSG0)VvE)&uqH!7 zKCYYv#F$2OHiAQHzXejZ^5=?+QjdsfR=o!2nMn`)-Ca5;#rauN%O3yO&(wAxoMW0e z5V>rMDW@f6U6bVI1rWeRC0P{x@U!TNigJ;_*xqvEV(OXbu}j!$c;SspFYz~o7X&g* zD}Ew#jz8Sw`x77&pVxai^R7rHjK)hLZF{AJ_UCAkRkLu2I3{-_0Q3GPvNhwThyQtWY6GduO*m zu;-emRAIr$hZ4<9odVmsW;h{*XUruI6MT8tfjr%0Jarw<5op2_-%A3>C9LR187V{# zE`$2An1?HOHWDjvSaJ0olP@_Q19wdMOh!nzgqkaBke#LonPIUh%N`yRV_~nZL3*Xm zbzl4VGNJ8W-~s9H41P$dmbYnNWZ?@yAXIxcNr{pR_ZZN75Ung6kl*h8;` zaCF?czYx7{|G%tTInN~A>9v|N?-GvwBrD=p69h(llK;4aXYF0YTHxzOkGCg8kR%lA zfl{!EHrve8m9UG3^cf?|j83PQ%Vda z>}3bzyE6e(HSfZmWCBsYqM>i=%G90T_>q$;hmznDl+I6Qzxu{}PI;C`^tMGqTKc`> zZl&fnAtdAYsNbL}2U=zlv?~rh`7Vz7ujZE#_AJ3d!TVjNg$ISG>@qUquC|{hXz`x( z>F=mY)uiSTCp^moX0o&>q+>OcXKKlsxM{9R+t9qBC0REIQ0OyY+D3G85S_VOFFbcq`d}YJPE2f-!?a%^_bPxH7 zXzX_uxLkxiji$ao(T2Fx(22Io{xV}#UQ^K%$nHa5GlI1|N2%fqYV26qIFjBU_9u(? zW@z8kGnYARt2hmHeL>j2gtoogAV3y?q$kR`C;Q9xMF_*x?6W`77DfNE6gFj0q1RitcWm4mgJZQ$qw`hfq{Ke zhFK@M9_7AlwIgTvn}<->bY#D~A&_aYTe zf}cM`s1qEPx@dh`Gs0q^;k>t(YY#0B$rXN=f7TbR#q9^nkNov2@D-X|CmT6}xOg** z`8xY1!obM|y;FUjPC{#sPm?M2^iQ{jKZ!)ujbPwwl;pqG66>7RO7cI-aW$E-*g0Ir zs73p8tde>G(z>Mc`Sp+7Bn3=!x{H^7*4xyW z^p;C^7$esR5~}t@?&1|wy4U>kWNX+v3DOR0B@Eg&Wu3I2zti|g2Oyih+$gcPJ1~+T4hd& z>YKR|cd&m3>^oUq9ZUw&_>8-^T?e8UdEL7H&(u2Sg_;2m$b}hQXiZKk_*1dUezm(B z_H4TFTPcffb{GRwc655EgB~T_0V*g@_|WQi29JC!k@{Ix?R_uoYXbWWxmV&v2m$$l zZgSTnJEQL*b+2|NZQK9o$3M8noN0^0-wLu^&d%ct%ICvODIn4BBMlYiFVsBH+nVyX z4mbz?X)_!V4GmOnjbcz5I}4rl-lf_U-mx%-!?_JZY zwm6F%3p3(VRFR+(z~5*Pl$ybX%X!aDiwW64X#x4`CXlOW;4dblQsC{U9b9IKW>D*PgDFGnCorlJ=X%8d8`|L z=Fo`P+rRY3GxFH|t9BU-G(X+w#iCehb6sAr0lQ_S!f#QQu8&rnPFH_z^hVpvRXX&( z@+;MSYcL!BP(P2)7){0R{$}NCo6BM#L@N_}y8Z2gvY?ysMLqyrZ)@NTnwx#C1;(bB ztG9i+F9Q=xQ3*c|z0&by`mWtcc65;ORL4Lrw7mSPypHeUN6Gms%}*#v$_;$mN74oA ztmj)j3G)}j%>zGL2~qf?6Tvf?$;OIwPe#J?zDNb?l}d^N;nK)#vp(F`7`M0*>Fa9&|_gz*lkiHCu;YjsO|q==KJe1wg=Cpb_@K3%_h9u&{0uK zRoIZ;2!P@qo5ZV#9OSV4P@&}TLR{wK%MQuh5`mA=HIIf!Qtf9$Ald+L*Z z*;l4~yw}E};J!N2o-w`GecPyM>#l4v~M27sfR{o8Xn()8LoW z1&H14-v?i(>GjTG^GxsUXO!cI!=2;j<5&8ZheH~{C%t;tC(|VTvvRk`N9ES%vi$Xj zl?K4}r>=Y`snhdb=UIgc}6kKJ~Wry3Ej{a}%XgWhj9(f=NPr3)CC4Fywro_6{^ zOY<(IijGn8OdDnj;tb8IEC)^A)=Dnf?5NUjFm8h zBBjyC*juMnhfTfK=ks+SZ+*nh5o%NEi*dM?#8f~Qi<0kq&~2Cs=BLsyXd0pP++~Qg z(WX6{Vs@AS?fmPnEPf8IhNH7RTt=mFS_uzLh@lu{^=!D>P@$DeRK6JM!t%yGX0N{k zIE3L=$L|frovQwuH3Q}LKiAA|2t0I9-ywy(R{>wK+h8<_UTo$YiYW<2>5t^0fh|RS z7LJfoCL0?~ot=>MUAL(&Baoa^O3VC+ z?eP!wJZQV~HbY3+`w-vv{x|l^L%6HH<1&j~tKn0Sd4`bl<8G|^S!4`G1##2-8H*!KZdVn^?wMBJXW0u-)mF7)C5!A_1Ha5zO22k_<g>4Kl?dF3mC;QZs^pUhj{wj1B$<%cPj@;_(AEZ-1# zz2vu_-#*-&tT=fJnAdr&{>^B6Sl^kcSh+2)f9Z%ycaOWhEqK1KXuT?XmCz@ir0;n* z+H&C?jyC$xe%veKG-{A8c(?R$%6HzP>27`Buk5|KO=j2p+(ah&cgcn9vrhBvobzY~ 
zSJZy_g!a4Efx02m*wB+-uRL61+@d2GSUv!ksoX>C;2VEbo=@nIZpO9!pfD#jR* zsM61P1eVr~cO(4S)6zaoSep z^BOakLfip;alKy!7M$^$e*s&G*;|{T{|K`rix>%1hNQ*cAR7sl+w%aR8wErN)*XTpH$;wTF##M@CK4{D5{<;ShqA5Mu{}#-F@)$rpTekebLR!fbW&M zotHpsC|*5Mh7YBXqkC^)bDxO*T;h$}u%@_OBX>jW-{{K_-#J?EvjwsZTDm%bVfhgclM~k~YD>#Y@wtt_KxB)d zReC-nNv1EtN!wXN1P(}S)4{8bME<-Jn*shE2`W29&GMeMEmRY84)8OGV|WhV&O*o( z4vgS4{fwiSK2gw~nYs_URHHbTlr@6tqMu-)rr2dW!OsCBMGg#`-!lQ$jUA7Xmi{`} z1Bw{pukrrnu92w#S$rNY4u_GRk< z<#r~wke(0#wG}u9rB*IPI2lm|SSoAw8!(j`1JF)x`(lH0AlhhYw2Rn~mp0c(4*Rz! zo|?zKm=W+FRM_K)U>%71W-1I;&1y)@|H0*JD&c|)p#IRDf2{@TpJeISA-YCQvq&dg zu;mX^;AD(2pyhu{(#Ftczu6ZnK1%L2UmkATyoL4nl`ES~Fi7x8$W2-B?g^#M^(8Oq z+nLkJa~|04uBtn-bfNZQ{gtQf7-G=;vL@Yn{inChX>eG;p(fkq0pRw{ZCNq1;X&l- z%3S2`m%)?bGbG}g9{nYG!{6XxUz@7o#K!KS@uuyB*Y&uXYS(9N{en#&Udt==@9}=g zp||97+vUKh*L8em^UY-2qx_5WRd(~kxySR}f!k5j?S9F(-^3U%cw3#(g14LT&AuC% zkM(xmgUx5+>m-dPy&eK?$M+*{luxe?#q7<>OWl(jbd`9so4ch&8xXB{Gi(*t{H%XjPO3DX$K-abU37 z`l;R^utn&TjX21!TmI5zE?+gyhSl*cp|IO%d?CR|=UCxa%G{>P>F_Y>EO=2yVjZHA z$H^ZC(N4ESA)ZOu06sQvq;5cyMjEkFY-?h2y?%OKL8`X$7P(;nwODB|^#R_MaZ@1g zEaH0d)8MWdC_*P$!_4&yNL;`&X0Z#Z+$DV`%7da#L--9Tr%`;p)CW+?K1kZL^{R&1 z9Vy8a^$cPsJvdidvVWh=;YI2)ql^0Pn6D+8b1v z@J+0PHYB`HMnXB`D5W4Ys zT#Oc~Uj6w}Z?C%qeM;na+)NSrk@>IgsRL5p_%B0pAtI8DSKnicn~bN_*K-tGS=h}H z7q0Q;tyGq&K5x_}k&+)_I#-NMc0BQ*s#gs0x7)r|(@bN6cUY~T!$xRtAVwKJ?#n9S zn_s~Tm(Io-u|2d`dI$ySe|1+VEQ#|cEvBw21$tzcs-4&@><`9WC+>Hmj2%6jZl=QU zT$gzozl-J*{>?!H*3kc!Wep3ZuFsaDl6;Lx47j<3V+tw?xN`u+dEJ0XP>&9YdSU_m zTO4!TVg2xhOAQKO6K<&E^a6qSYsw2Q`5V+pm+jS%A}2Zn{UuMpuip&v--hmI0j8`h z#DaGxZk=;9FKNuhL>DQKm^iKxl^GodGk*7s$1`iB!#j+iL_YW4*O8C0*E$bY)Op*9te7`xQnCzA=4Hf zCcLS@WU-h-A=&e7b&uY5o^y)NlOwP=UW;#NV*S$jq=@t5VH4eY9CC`*n`nm3p|}8s zDkcWMrLYg)1h3%|rj=;`$0?W1OMiptj&tK_O6n#5fhsN=ZfgtR9>>7t&Q!bg2yk*Y z;@0-m{Br$ryYwfsX@3A>+q!o>+j9SYRxi?*;k?BhOXQXOs+Fv~+46$uIVa3sr*_W| zeb zx7q4#yYHe|+wND1Bb}}iD;ln5wA0((MROB95{tG$9wz)UTUs}+seVk7{kzT-KJWGQ z9__7YS>K;j*II+r%f9-T53$itXKs`~=33a=)~?1AW14mDdI9^bt|qgn9s2k z&?4wErwpk?z*4orFGA>74W;yT_~c^t$DOV3xG4bfhX?9&iuKJFQHo=Yt2hR3z)`W* zL4rbMrmQTh8om(96ti?0`7gXHRUZtLV9-`?Q#?^158?@%Q~MN@A@igEg+k}$1LV`r zP+p|P7uiWfybYE)^?Hr493?lQT|0)_maS5RcPOzq%Y4fqv$P3AF|P&Z9vO^Oh&Bqd zsM9BD$dyBFPEn1CD;xJ{7Si{$K3X|Y5Dth%j=;gucI>sqqSTV3F9K=WiXtWKl2dbQ ze@bE$!DT>|N;*rzJ7J}x{kPU#)p4%m@5sN+gL>?>dHv@7tAdm4PtU0XKqpH&kirD~ zXV)$I^%Nj+y>L<=`#wz;UQ zEcOqR)8&UbnGog~szc(SGyyYVc#<#}nRA2OJ4h-f2c2jwNIIYPNkOxbpL^WtYr5l9 zFD{cR)mFW|eW9jL=_}$Ebf?E0#0oZ+=eit)=I+^D`NA@^$wY?h-39WSB)0Mgb;$`9 zM#@D5{RaSW-ODA|$80fBp!NQ566Mpi_-*afjn_54LE}sKt^Qp7%5~~aX6FqCr9rLP zRqFKpM)K#?{*Aj5?M%M2$Iy-Iqkw{;dYeC|@9o?NiJzY@GM|LG+wVV9XKNt#$9c-P6P8 zCxDOm&4JG&V_DfrZ#31@u^W|u&$>okXK&P!%TCE``5#nVqGy_Z^&dc|;<=IH>7CDq zSM>E8N)JQM6kOo+V~y8MzF?9*pjBclv*#)rdAIro%R7Qf6N_eL6L6v0KsH`c#;G%9xm{zy{SvU3 z9_M)i{oUP;x|lW2&!|Lwpqe!!`da(XQkOGdz4g{BLD}k93T{R^9Hjw-{u*hrPH$T$ zm1?RAiMF-do19X0$!hF&riK&KP2rLGfA0vl=E&XFfW=t?2! 
z{go@n9wt7-?e)zbVR*laLbs4+Df#^rnhG+ab`-E+%YPgBuxFN#{^(W z&Ap490XiV+MkCf@O)QLqs$Pqe0aU9wUTRLf+1!H`uDJtMn9eh%2akjWm24Uvd(4Hx z)G?(EGu-V`d{N+HvV#N@k9X?xDmr-#8tIi|bRQe03)FWMdSW|)^Y4L@) ze63R5m~C_nDc=q0MDae}#tT@TGu^%7J|>Y%>K!X+Sj5P)AopLhsf7c32ag4OBox2T zlc3NgW+sZqu?Rat1DzcECrJOb#D>c#R_@Rv-cjX1oJl_;@MO+od~xOec=qL%tnJ0W zm&*I{ciVke=6`}nt1I7Hu4bBNHj>or7W@oAk-}Wq zI8L)vwB422`Rttdp1Pfe+p4HnJ3Iwp@YwRZ9`3HujV`R|{V(#~@+qz_iWZFEBm`{& z0fG~xacMkABf%ZIad-C+2o~I(;MNdmf_7sGZXvk4yIXLW{=HXI_3C|@nh#S`HS-Ug z-h1D^_nfoVUTb?T0$cw*E_YnX`Ck+UoJ|ld2R!;Hp=Te3@<(<5&=Od1-v1u3M;UM; z+j`ZnvO3TCQ*cu3L8AwXG`vRS`VR{){d4=~G~aX?$yzDqLlG#D4XQrTys&RRulkMx zCE-vcRjbQ<{g5HIiSWE03QHs*rHf1zbY@Sb-giyF4?Jj|%>xp9ti+6Huf)fvTzPMTMtDZJv6JN|}7E~XKdaqw#@OLxB=oJhpyiYZf8)KyoAuuY1X7}f~yxZ zX=M5Z=X&2zOA$KXxw%9&e*=D0jtH@k>0!)pW;2szJI8LZ?wxI%eLYU*dN4^gpkj94 z+3;JTE~+uh+4o@J75nH7TPJ%*ZKe&+&0j443Y1)&B}(x9wpO{e(^AXBHyy|SP=@JB zgS)or-}l+aV!j(Et4gSCWcl~{tkoM7%yGexrawf1`G#t)c{}Vn;m~1T`2qiGmnm@D z0o8QJ-~VQuFcdV)J{Wmj2=9qq%n{%cQN|$E3%1 z!!W(cT#DyoB}&kxlO_Qv@8f{Si}qkIK{EgLtIO#{_uH$nEL$qE`&a(T@5Fm!j1J*` zA+xn!!;fD+vo{#?9p8)fZ4azZObqWx zw}0MIis&9CbmxVH2+`O)_y7dH>@Dr^ZYSrwkY}8MJWk!XlaF=z?QEUo7=Z9zj(ozV zK#NzLpoiusHhqim9s8suBzKxK_tztYv8XoiliQX#xglpZzQUwTeoRTEd02BJihN=& z!={Mo_?8&QsZI1e3|lKbNJblI`?(BJsxnUc$Yxy~Dgu{F5FwSmniMDh>udL*jfg^> zFYSIid_wlrev(gEtfG%&S8F;D&sqPF^KAHDE(Xs`%xoL|VE^F;%Z!h} zIbAiB+1h-v0Fh=Jb<6u@br?nc_l-vq_q$2B8$A@ZS0JaHESleTV>WKhm$=)BloOjz zlPu8T%6>@oki@w=N6oduell$9#mXhb*k{kHc~S?%Xp~e}ERMCS>SqKWzPSl|-oJf} z0pI(kPnPQd@BQ~!f3|bvJpR!+{``2e9do(hHY|F!BXH8Wc#ZhO%HHhT{_kLTT}$WJ z3-M>MRteV_@?v+LPH~O*HLU!V%id2fuSEaNAnE_^&Q9jeaCE+Nn&g{qIfE_V*A9z( zUO8$O5893h4vUtJNi%Ro|2=wcx7-l$^cyC47$)ww_HOxpZ|-uxvSn>~PYSuC3+%g&EF(q#7 ztqLRRL~6plNAtVghl=uzqo@;XLVaxc;j3iWL>OF+JRB7tOrE>8wm45|n}h>oPy}3w zql~AZaD6Z@oip0@4^B^cvwXjxq^J%!lcFe+Q(%=k|AeCszQOWbHchD1I&ez8-r0i+ zS`M2Uf{`?Y3g6T*oVAHlO2mVQ!__vZI;oUhekdH&N}p(QO1-Ba? zV<99sq@#5NkywsKVX&81fDTcZ?3nbihvz_qbG26hM;qg*eJ`Kf(Ifz(bTSE_UwsC3 zmkIf>8oFC^LBFlkPL?sWYNf+0$H{-lS?iI2l7!Exu7am z`#El&R;RIV6DM*zcz1Mn&wnp+K1IpVa&k}9{QhYJ_H?#0N~dGs0rYRWO-wWT_oLbH zDfI4fOZM`c)PlNblYp+|B0RgetI$o&%Gj8~hi)P5-`SYQ4U? z@;`a=xZG`5?W_PdtjB4s*sA~Ov%dYfrSWv@^*Eh--;ex#zcw(P6Y!|h`fpS8?w|gb zf5*S_CwDt~F@AGy);rzLbNKxe_uskVLKo-q&&+y?2zYiLJAd4C(<6NJcp=NtU_4%K zbcx}2yqzzF;z{m?|JytlFaFFMo@Y*P|H(KCr@z@{hPVCqL}A6a0RW&o^ar4FXx;Xq z!RAo^gg~JgxX=#+k~&6r>rf0fvX#W>s;#0W_9h(ZHcCI2Xrap9o*_h>D9`uv4uRW6 zg^=MIo9RTU-3w}ti92eS(Z=f0hQn2;6V>DO>aY^b+;I!8z2hX&Mja7|NiE5w4!EBX63_tP2E{2}I2 zvkx8txp=X(tVHj6?Mq3DAeFJy|maRAzi zOBsz$l@v0JdYnAbN#M^y%7yIEANyS0 zEuA8+R-ovoPRzegUyJ{;8curOddKl}9B|%xyl_)t{LDeLqsGLanIIyFjGqSD@kN;-3s3QP%@iVT&y6;fHQdvCD@=GGRwQ_IZjr zO8N;dTc71K}}WcZK5J#S=`P|}-6g)0Km}Ro!P+P|GGfgby#-7P7nZ_b;cs@J;FT?ko%SJj zS$C5MSqsaRnoXQd)zcY%M>csDUGAhJh+u+|%Gfp&gu5K2^^NfFuuwDA(Y-h>SEE!FuEK^*?nzSx4isO#kYe{Mgp&I)NKx)^C0V3 zIm(~sylXw%iOwUQ@v6P0n{YeX7JerW)V$YGpaYY#8*w(_SSiDpk*b~BynxCn8qE3< zGXs?y&q>!_lbz9x%Hyf{f9y{jm1yMq-lZy&hXIbJ3&gz9pp{#ik^dy7|E7S`@0+3C z-{&tQa-UfJ zSQ-2+5D)mg{;+@LN4|?-(qH%On63948q4B-*I0~>3k@dwsklzfmqYdY$`3sbIavOP z+v$Kq)bPLCdrDz;H1GA-UiZVT4;<{x2S3I`NKZaD1>B=PXK7CZt+z|-af`0!l;uNE zZ~ce>4cB=DrPtEq-gq5Pfa6TnE^30oJ_D)IS^ZD(Dz9Y?**L2CD*uO}Y|=%)@xy7O z0G@!&E3$S97B4k#L5f&;-zU3891}cV-c0O5pMkFRd;O~RpEQTVSdH5yrQ#jg)q@)g z_AABWmleGUr&KksJq;t)cp*cVlq1~sjrG?_F33}Co9tjlE&*`Ql<4*d&PWF|M^S+$ zYG`J3R%Ebul2?`>)X`L7-}=J6fxX`(e~|&c$SyO;4WtjY0?$iqQ3qVTndT%SXr%Vzc@Yh@@?+OU31< zfxC^o-74vuY1*cMJcoD>B-t>VkCF)seTUsNWNv*Z7*Vpfp^VAFn`RJXt7|T*pYJZ! 
zD+x6}sjwFAqo5zLH3a)^v16pZP2&!Prujg42~#VbtD_|Nh@j5jZS{67>ls_Aq#MOFDc5$!A`d+B&IMvJ=v9^!>qlHCA8ueMn~<5T)j98b9n=d( z+!O(X7mkjvHqnQLdZsl@S;=T$y`g^(-qwIc`t&pCv|Nd{x}G&|pUxMbI-mH75Q_v* z{-)r;AF`SrV87W#>Yv@?4(EdWU!@r@-d>jLz-gX$lcX8NJP}L81n8YPIDJ9BMD+z+ zMcO&G1-$jXf4h)l)a>|hV>iWai0(Kj&A}$r5{9Q^-~)UV{v`tTN0AJq0m1(=+K>x= zq>sW%=K*VXqE5dywUA4G(-GSM|K5L@^>)+Qt)2<)J?(ZFQv6Wk!n3 zWM;H4zRhpzFVC!=qU*9Qi4Mo>rh!O=X>n~Glk?Pi->huod`D`Pl9ROOy8gD5q(b7Y zMlm`CtspIni9?+>|5&7CeYVtUswKwKv8wc#ww)@j&S6MX%JQo2USqdJkAU(4`I-o{ z&Mg$S`k8ybUl%UdL^&7KKgSgX$PlGSgn}J{*hBs~szC#xk4Y}CY8oH#{u~br_GM*- z3=vV-Pd61lx`iwL3fdV<#n$Sm4AK^UPY6kz$YlYp;P(@#VcQkf*Cc2Y|2U+;LGtpP z_=nLZA|Sz#23w{H_&r{)`Y;fW`?rhfH2(H%K{l(%>*XXpl^$F0*x?93aVl1?uJ2$q zeu2f7$epInD{2mB;Vd`WEPn9NOVvMMNMdV)-aTi2Rp;$`jNm4g= z6F-tz^r&g`JWMV&{FClGcuGuNfvUr{YC~~G;FU0OLfs_!M2D7ezT>7BDAm`|CH^#q zsQDXc7;w0-x=RPU$WI)wmNbLF!|CP|5ymy1ignt#aZVEY8Y1et@$14A(?4@JNmX)F zp1%}uoQT+2vp5FSlgw=9*6G*dEaNgSgswJ5iCzW-b#^4*@=PXm zf{xA=<}Wn*9WT7Z=LQti7CD(XctLY@s~B1x%CDtwi-0e{S$b7GGF@& zr0PD(ce>lgl;9z7!O?}yQH!A1`{b;kQF`!>{o-7t!H=(bj5)D8Xm&|W^E@2|LH#GC zuMY&aEwdwV+t|YZbaDt0+hX|4f2Wv6uIL+9*)P$#o+RdS=uAaQU)X|+T_4$huv?W8 z6!P`e(+(o=UABH=d)SW$ji1H2rsR1_g7p>P6Vj&zC1_b4vx2%*{{^Q!-&jXLUAseE zVj@&8qk%aeO4~EK*zHIm>;Fl)-8zxh9pUKE&93!Pl}JU_J*ZxO@sP|!t=-&v0^{Fz zx7wjDbmKIsC<1yYRsERu0ZXmc{=(ay`;6|G#qq39U_FVX}7tNZTzGvoKIU>tZiH1t)}rpdG&yT?M;1IyaDV`e-0~N&mCu+ zGaFJ^7JMj4u)cF-DiBL2&Hyv!l_wB_qF$pk7V6Rn7@bi2T+z|L=&Dj>YorcfRcrO+{(WbTA2tQMGG#SFnYG_g*Fudoh|> z2{y8GRc>yfVpOAZd~CA|>Itv62}Iu8Mc}@vp$%o)mU*HeyLIQ)x{nV^@MoD6^kPlAUlY5`BM%AdJ;6 zF~*zDlf8|UMm%G6)#pn0(YcX=pBjpw=}|sSl=l4H4h7jV63*@3<@m?b>f8u9*@-*b zn3ZBD0$=CMqV@u*D4E<3yRydrt$u>_Pk1=(W=)>k$A}{zsH^#_z3xcl`H~J%T{g;g zfq&08@UKHGA^R?1%QVLk^ShGEcWam$zCPy?OV?d}ax`#^;PN+w&U+UcLY2j6Q?)Zv zFv|4lmLUDRO4bXJpogJzi{IXfic8Z9*0=yICr-Gq%Gaw&G>7eQh7yuDIcw%(5>j@_ zoRD8L@$SO6A?z54akWLbnp;ylcHh=eS^0?4W}8fr8?# z%IGb8b6h&9C>s$Jgnu@t=y}}eFRB3%i?o9W4z^Bu2xLGrRrWD9iGM&eqV>FX5k=qt z3`ZOv+sT!iI?p0HNEY$Wb5KMg&=w^>Nq@O-Rbm)tm;TvNlYxA5P^W`kK6P3{*K1y0_L$&i{W$2gBDb!b)lKSPN?uF7!mJ#5B~Y^ zi|g`8I(=8R41GbHMQ%;kH?lXo7uAfmrF?_g`M_B|?6la5Z#LOylsl=TnBy^>w5}Kg zpyRY~<_rO5EDeuBuc@I=eHjaF2tgR3OY^d~uXtSyn*;mdT&F-UD#- zO}^TuNeXp)>DWgR)`h7q|)drUj`VHxstjuIBwr7tZ*ge%Dk!{$}k-!)U3P_@exmV(yU5wcV2y&f0Z6dnN-?7Cz2YWg?540!H3i8# zu}t#{M7Rl1SY~R&41a&-e-x&qxzHgiHZ@3@&w;l$2DotX;doez5qBG7Azl#AgU&WU zAk8Ol3}A~m5w!{XN9n&(yKKm}+l*xdIz}fkZRAv_c;%9J*s+xT5{m9fku9lZFyh9k zHP9OpTqsz&nY8A?vaOd?*qX%JUcaE8QPFSGK2>Te7{wsIq4DC2RAF`;@0raf*%KVXn^q!IxbpvYklRo1zibWJ{5y zcDeUrbYJ4<#SdYc{_5~Ee|Qu4{b5Oeit0pWuOD%Q8{La@Vzf$a1RBp8^NT8rDD8bd zi9`sOU{o@ODt6XT2g;vwBeFbY?Htb4eb(iI$5rz3a|cTK*U_E@)%Nv`C4wibdZqM3 zthiEw%xx10QKCp25tyq&U&bf6E_6C8pRa|lZ=p0sVcPU9;FEK`t0{|8=_$&|j{)DF zq%4>YO7(Tr2R~n1bUc6b69Yun2W1dH>X@mI4Tk&tl}$(&r{z?USh2NG;H1l;eFjBX z*BcU{Y$1|r=vpZ~ZY;=s|)D?7XYwSYJOHw*a9eztLS^0?(}59(!dwDA6w z3&+2Z4_`tAOsrRXB5YPIGv1`za0)4Es8xdNA`5|9U4vnGGB4A)?TBK!AlVu!+-mXm z)M-LGcCs4fYJ8ZJKS1R7<&LwzO_BmAcEw2VTS{B~+aiQ%PIaYyTD2Do;ylMUK7H6j zbu?6p#Ig-BA%<>E4%d5Q-wU~aNGvh1+}1Giko8oseqBwE$cncOLUXetscrnL+XU={ zfh%RDt-uS!4tj%)*%}zH*TVq?MD>oRMl9D9aeWSgqQ(8|>H-2T&e%j~&jy$Jb=;$Ja-ZCVceg)^f&hG7RWL|8$s z>-ik$>LX{}<0l;lPHYnP2}bn@1i(;jH^=Q{iB3>K=#DfTdY^_P)=TFKRaUV%YnT{R zqR{U^d;YQ3>uhqNzAr<&^gJ5RGNp$uydc0-oTt7%-?ji7K6O-j=w=F&VuGib$TZG! 
zyV-pToHq$FmS39%Kpd3j<>03Ll;-vG+kk+F z4evJkiDcs_Jgn5-|Dq#RxXoK}F!4jWGameSxu5z{YY1MC)~c~41AA`W-C?bUD~SfB z2@Dq5A=^}4B{>Zz_E}gcJ(RtL-)3Zw+jTb03(}4iI7HS}s^P{36Ekb3Q@vEh&EGH4 z-XI3EbWx{NSw7Q|PLhy!cTV)(-Pq+Wapjd=kq-NgV;m_wY>NFvMu~GilR3?X3WH!4 zgmfpH(ZZYpV*bd2+F|>{x?+@Hjzorz763x95Rxx2eMWfxYy`d}TfoiDJ07a{utx z%+9cOCpbvfe0^ce!dQnAWv5c=>Ii4JmW`znbU8!ABGq?p|KlZ#z+>rF5smFarWm@9 zg8y^R{Fh;s9M;*bkl-`9cW@NpF5O-og3`>1WOF|^=Uhy<$qP-xaHe3B`0!0gEQ2qW zJ;N-DR%Mi2xcfRZ-=V5Lfmo&eZM0O16CcEznQHT6bSgr9%Y` zAFwDe`q`sDx!(bL@IQOW+2Qy7scQmzg8II-(IZgXY@$b{D@1ZrvC?MfZD9r5`G@`k zbIVw5g*4gh&qF^}bRVbKA>IoIydmxGvX-F$L-ID4DzyCc{cH0nP?LEqFj53;cR;*O z`VG-OC8I8aZpMkH^?TBKwb(oPB^PAwlLU_@bb`}mBOLT16@QKuaZ3Ptm1yubW#=C- zn&#h8{1kP3ij7Cqfph%pKzaM>eF!~D~(CqsWci%UjNZ%Zw4w*M~$Je568=qgLA zz4{(*P?QALin+d@VP~cSiR-bg%7Uj*0-$H)`IB0#6R9w!6zmk6U|L7SH)!@Znsyww zeG{R4DEDNv_Jzx16Iz%CuebyQDc`l_=C{~OW|;;+uS>|5mpQvqw*?!oc(#NFg@T>n zNTZ&P5=9`8xQql_`EA;R;$(?bjB4G4$(%kV0VQazKK%QF2cK5EIWBlL5BQ0A!UdA_ zO7#Z98W3UIBez{M5YHSlo2)WYXP;lH1X7@^E796}If4bRgybVcM9UOQ72grN)bcIt z@xH>=PKUl3$6ibnt%N>4Y^U!oMJ@+GU4y?vEoD7w7$+1jhYeu&8Rm>*4DOO z{Y_3=*>NoW@pS~zaMr`?MEk+wD%QxTOhX;R>mfa}n`{J;HJq}(!uQA?7K+A!iAdc& zMf7nfi-@XcwNap0*vWu1j-gKJAEC{#yGNhVBih&ecT;2E2Q3FTM{oj~c^{CbzAu2z z|B`G3he>6Xg$phydJKcA?{AL4i+lT(isgqCw1(92D_ZzrG@!$)yBH$?J=+27y z1zYy^0aU?>A6FzziuIlgo$l7DTdfy#*B?zTkYUd>B2-+RX3W+a`(>k#0HvPJ0ayCj znSW8O{=C;rcS?5zMc7N+?Ih5tJe}+geiwLG|9kf^C(lM7^)HJ3m`~Eeq`3y#t?9;8 zR*AJtfN;t0WdwzyzNi)!<}LG{ngfTr1B@*BeIv7LD5ovZXhz*Nfmb-xmR>*->hW4d z^p~8Y;P*eYHis1YmhZMdGL06WX=veeNSyix3$eo|BSVPIk+7Baf6&0$Ifd4U0d3s+ z^7_o^Zmp=f4evOw1V;p4cjG92X#c69b(FR{e^F|Ki%S_c(A>lt>P#4=)`}Df98a3n zCC&jV+o_sN+Q!syb0Z5OGT?T-Pt>5smGBM|=K>xZRHFDO8yLqB#2$H$uW!PFD}M@j zY)-J@FaJ=Wa)H7o><4j3@Y*GYhd!j71e>|fQKdcj5c#6EKFyyf4Gyr?vEw6)DhJf} zpHdyuF525gx)svkB(47%0l;_c<$x~hYwARhzK#*Z3O%6E-^u3%jH$Nzg9tQhDzD?m zRPW}T;gIz* zTRQEJiWV~#X7~Q|w)paZl`S{GdcH_c8F zZ_7v#GxmPY2SP??;~{H3-NlS`@u;n@U7k(mVd)G_S;f+fKorH6c&>GIv{K4zZB$V2 zVrThMLTX1|Pd&{Jp)y0AFsSM4B0Wz(^5j=$waBk(gg62qq&GC1B%4ezkDtc)n zty^UQX4S>@a%IQ7p|8qAz`o)@tq)(|Szlgf5#iNVr&Dw)bBKne`b541{eu*_EgJfQ zlnVxi_@1pICFs%%#_$qRzYHCXhnq{~@c`{m4iFu-6;K znd->S8jujfW-C2p&g%d`n!mn;h>C58blS5q_EwE64Ch3>eJw z_yd~{`CFFn2td2FBiz(|IH(PaqcXVm7-{xAoI^N3=_{zav9nND5XH1CT5v>$ zHbngz&Fh<@0uUEb@hk^*K(gs&C#V@uF&Vz@damB$6JmlZ^fFWU|Y<(axL=d#{O zcPzWgw;T02;byK)QY6RCWK$dE!OD^##|p-rR3bf+^+KGR5+`eF<=uwK3N8j=$XhPbuWZjd6Y*TD+C1)ma-)oq%2u)4TIZp zFvrJ2y7RVmB?mGAo{yfWb> zTM(1$v0QD|d$icxdbM~gT74IhW^{IzYWc#@sAFP8jFTv&Y`&%K*6gpNmHpZ0wKArV zrh^0|C|Ha+upbWHmZxNU17+j28bE}W^TUCjLQ(hDhUd&#+QeI_{BtZrb2nWE(z!2O z9MRqEafecUoV%!>7v3S|E#n-vOU7cHRbklHo{Zx`8B?(x6%{#K7>BOE8AB9@{RDYn zwQO2*FZvv}{rjEBSSoYy_tb67_9b$Y+cu`)=S6s_NkBkmtHpeW*7qrZa5mVV_va0< zGAoXo!aD~Nt3$sv_5zha5Hd&*fo?S*0y2&NI{GkAU)zgt)7X^(m5!tF3MbjLE85AL zXltvu1$zn_5lNO<{m#l5*Gpw_cy5k#Yu!M;BGFO7@+hmTzRXkkWbC>><)e-{&=0^J zT&{6g{h?$Rl|4i&o;mG@10xA-ERI1$)_RJ$kdDGWWTq#8`k>&2`6^cT*rHS+1!M!@$i?W6iw$Lq94y%UHyNK>g~BBkFJQ=2i-%jXEly(bmZEicQ6SV`HX4^p|B4ie;G90|opP^A59!6Nvz< zURw<@2}5as4TihCaRT>iML~)`1&uS-RLH-Bjs${RU!u;e$x=}Z9iL^Sf&S`uZNDF`!{~lCtVO0c#T>NOnCILa z^}T+_ukW*&;!WG=^dxrQ@=po5AUzh+u_)y3oHml0PyyPL3&)aG{4ho!Dwav)SH#_A zqY*GK`6o>Ow0)4=Cw@zPu*SVs=;clc=VFHQG{Jik5)DiDm4>O`3o?Bw+uSMxF|Ko& zZ_=nyC{A0K0)a2POI>)E`r9(QvFfkuI}aC%X1p#lvo^3FWxcf3gx%bLvZS-N+ZBh5D?APQp7 zz5Zt?6B*BTFUnqa(hE%|Y9mv69YW6B+0W}1KEU8N&e)Rf+GvU`#QEEafVZEXrY<&L zS^7Za-)iB;2uM_Fm+p^l9na3np5m z9s4vEjAE)n^SEB&ASHO@ACywg^3}5*;c_xmMfDp*&I08;UmbCzva#cpbA>n+3*3M< z9#EeJnGgSul;wJT`C%z1C3oUTunc4Sua|6)vG~s~7k#70{@cyhJ`O1-HPs_W#xc-# z!yf6FjwQYLj|+XGL`rs&Ez^lJ3b$8Q?XK0Z6Iq#- 
z+8t+MQ!IzBWX&HN6K?DEX^3rx#%Jli`mx3k&WKl?6+Bd;PHvTN4Bw(AGAgRL|5fEY z<11ZCn%A8+S@|xF-VQg9eOsM0r=3m3X8MP4}esiA^tTKMgZdrR#=P;j4#1YUyBo9j!}+d3TF@jU=& z`=4-nmk;19S;dEjo(T~A!yCIGLW*rWueL0?NAH7fxnN%xDKl#T2O9-~ieL}j`qZL+ zzlWiJ?nUW9<3uPEM|mi7I(xsa*rt%=elorU1|k*u!ac{44Ev*XN5@~`Wv&jx9@+ox zaS~2~i*_mcE6l&L-1_l)sI?C`ZNh{ zbuJ3`qaK07)2U~2h|L;jkNA}C)W}fmlppoa2~3~whaT~E4~GhI&5OSvOamm`{l;7} zC4H6V*XFF98G+OXRN~bnP2pnW68-u0-w(8V>2%!@WUg(Ol>4e7lz^QKDwA-|#f*WA;_v+(dF0$xkiG zc6}7-EnaZfT*iD0f5J<5u_0(YM!9-68#ipQ_S_A~pDq8=Rv@N%@tFGfTrvZomPZt&0Cd* z7<8?d*`-yW-0|qT24)BoX22okGK;n@)lk1RAc60!1{7KIdR1FBLkq+l?@`p@a?1hd zOBuD3kl)qpQ%$f{H@P8Zm$%X!i!|OP7!zvlA9#Gen|?dR{s1&?A|M{7!Y*=xAES+z zZcGSOS9d$Ds?pHySiBv4TMHaT#*Re*QnjZEt9e6_d3$wW7)zF1d6c(^$6=xe*xo(F zn+!@z*c|RWSqcTRdq%0WFi3KIw%yl8&@AAv)Z#Fq*Tiaql%I<0OSpZ3ADFXo$XL{s zSWv%!Y%F;GR(To+_ezfm;NAQ=Kw0rgyg?H$4&*5L>dRC#Ps?u%|Jg zXcwxYq7nbjcs>3unxygJ)ok~PGLXJy)hMgs-B_qBX00!0^-5TcD;YOfp-I}$D;`97 zMWjKnJv?eJiW?zcq432v!v=n`YlGwZ(w@RKx0)82OVj*zsSGR`!`UkBT?Y8BM3Ma? zTvl23ufd08A6}_9tzos3WCX_Q?@>tf=kyJz zZnOy#%QsgOXZf`7V?09Gy1yyQ+Pz>V+=_-P!lO(D$O9)46IGIM{2C*5#G}3OkH?&} zQ@CEPlwd0ji(}HXHlk321U1?+yb(8)JI-=v^~xC`ME62*FF2UN`c0CbLg_o+qLpi1AG2DcTXkTDYWbkM_1YeGqnpx^Y|T5#D}qtXCn!k z1MZ1md#IEZme)o`rCFq%!xgo{8_QnADvQv*yKsFO%QA28nItN-(;cNqA{sLw=-3H% z?Je@yc-d}eScQVYx0$03m_h)QjsKDSf6HS`gfL2@Zjk z(wad_f!V@f!)Y%(IJ)xMqYFLx$@6}r`!&+*l*4RX z;ntf=rxR(yD{Kz0H#jzD@*hc!nBU7X;%=o(i9DEX?Jtn zp-Fi}%z&Npk5NrF9R!YZ`v5JR{JBtu=8W;m>rdUgwk}y`lj6cRyiLsupMUP)qA~r? z5d4MPem&(&4(QyjaPxl-06Fp>#8kbVR;yE%NQxCR< zif_5Tx=o~K&1J_nU7W1GHK1K9O8);&6feFki=Qr)y-`{HDf76&B<^$CO*+j{oq;rT zXNaRaeJ)=?(pvLny(i|4a!bHkHj!qp!H7ZYRZrjY^GNPHIR3lMa(dScof*ASe5Mze zKTD50q$LB3f@0qc#fOe-I%V~dt=4o0mL%2xf)@htNGZ#zg20J#wc(Vh4VN{FI_0CQ zJnUR*cBO!vCYW(RPAclfJF=Ct3O`89)1iaOQ{<*S{H^;!gXPsy5T4DDGbP8e138*y z{2j}GS2OkVxF-A8lt~K&{+5QnX%Z}%pGq~WCl^i>7_4LOWkQ?(QC*@J)2gq-=7)gt zMzr(n+Q;$!xuq#J+_0!vP2F}{`69h8JcR{cy^`|DofERO^in=eT$M$QDsh#8a#cAm zz9JrJPx+DZgmL=V;Q@ckx!?F9!iH*HLxd`J^d#=UD|;d%+3#QU_NTpOhYa^Q4dL3+ z`0p0?X!BJS;tts8yCT}!SrN^g398g(dF#PE=BvDu@0eU=8th&2cFi4}_De%M!LlqB z#tHOeY^Grhj=$K(CvDxYqZJe<$bW>wj+ke1VVhbXikivC$Y`7CcP9I62z$3J#CaTy zusSGzelWF)fbk{bG^KtIBfWfXDnn!Z29AdlRYkK9tI+J zr-W2XC9@yePLZ9Wp5U&_5-!@AvEAOn5~VfwIz!^hN{@;Xc=m76Of^>Z1Fe?N?afPS zY9YVCTizi6J=yjc>h1|4Q_e?XP^T!~4Bgl|NTqS}XoC17xSZg6B=sYEk^9COmTaIL zJ4qjbIqsp55_~{)0Drb=g3tU=pGSWrmc6Ay+2IqTzYMuVX?J1xL|t>nt$Wd{s~Hl4 zQ@tW~0{h*Ib+vIu7a<4-j?T1@_V|?YGSy%y3G+XerC%%k-{}$3|9=ao)bH^2@@h%+ zbk-_&^_{(8z{6!{*c;hilHKLDWxra4L+4NXrvoeTJJh-8=<(0LpZ@1P97`)mdzcdz z-=1O6w+*TdR^BBrU9E&=Es{cfG5Yn(7f^N1Jq3IN1o-z0N{JY7 z{!-AmPmzd6^puSXq25Kjuvyg9SaMl*TJ zmB{9$sbcO@XR;7s^r&V4H!vLHN8D0}#^lH$Y%_-ZrC_^-n*I9{GeIGW@kM0125#A} z3b}I6wIrY0n@mk$eS@e&eNn^5zaKP6T)?rpjy>O=2LhSjUNVHU&(j(O#^&`Ds5-Up zim2xZ$?VrLMX>qyWmkjkZ2oBeoefQ~ot1B)KJQ}=cXIjqk8rIaI!{qgnwd%cgsI!W zX4VNUNrjE1PTI!bh~M24eg5}wGIS_z1iG8-R;hgMcd!yXO!4_X|KW4MNr3pw z&Bfk!z%Ai$CKYPBWd1hQ=lim-kMZ3YW_C@fC5qiy2D)Xb*udn_M(kk+Rfm)4vtT15@AFo%z?IQyy!-h8ir#g(nJ@Q8n@=~x8s4~K zmkWih_g7|q6oQ=-^&fpu33>v50*>Z91f8A^68o-7M;zu>u#Q{rmcF-sNop;LX?1Au z4<0lT`jQ3Wd*d-#UFfrm2&HUn<*D zc0U8Ve%2J|(fFPulwfvg=5Pb&a!9)#og34QTJV`zdx@OYgps+=*^t|OOqc!*H?Bm2 z72Il&leAguC+gI`l#U!JIQPUzv`ID;B`Z2go6lqu zeB=^XaE6?bpQ+BsJ&?0?=O4x-?S(<=;Te>*3^Q?138+S(JnhmLGH|q5f}+)9!igB2 z(v+jHVo=_4LpM|x>Tjnw&){E&;GLrNqN=}cbKsw}!N~T^lp)7wOmJdN-**@t%Q|qV z$Jz$95QiIv%_|hY5bp*`@Zs96#(HD7_tDT&pnr)<)m#VyyIRCY;g@snRo<(aQ>hn4 z2JtK$=7hD#EIRNXibU4o&*Q}sn_BkTWB*&%9+}cKP7Ft0I2gOz;P*OFUcLiMgv|_m zuyac!6V#@eN?j%&_l$TC2gcO*Wy=D^oWCo1WJJXgxU4@ctcJ&Znx{DqCAUcGfao_sIFskMx7H9l}6U{-(S| 
z7T6Rs4-O)J78@Eg-^1nfH!Dc7qi}{hGwHKsS~EwXdO~hbA>A+2s6ir?z=oa)T89Oy zJ~>6hjW{)2?%1|Cpd$|c*9oD@(6oA2nw^Q;jLK?{Mwj+f0vA*O4Mlt$>;Cpt-jL=l zjq;N5&6>nWFs#7_6Zjq9nA08 z?ahtunx0WU!pqrExleA9Z7l(z2VfIKql3T!L& zmg1DLx9=u03P}aQQA0cE38O~k)bih}O^^xR8Sr=CAHgg-rnsSA+6@?8J?aDlir6kF z08y5?;8pf*Ss(t)r*zcB4Pod2D@J9;W!Z$ZltdMe;0?WkbPbrMM(e}1ka{7$j%xhM z2(dC|TBd@>23CTT$sEpmmmehR7cnt^k^sq@Y6>l3Ns&YpF`9BpRLr(^sq&$V^nO?4 z+$MGzaA8*cm!{=q;T9z!-og`vaq7}!LYEo|G*u~_N|Mf=h*X_d`aJs@Wf9+kjICVH z$Pq^)=(5(6oe%V8g4IXx3_`cJ z@4$6_TzW@`w2hdlMjhvvT&@XI;lNx?u82ZpenJFG52KTXQ9Wb%|AN1+cv~(M21x?p zff@RzbA-z9qtqEhbttc0$yO`B&$qpCkA0z3s23wh_0|cHBW68%VOu-w6j#y^vQRC4 z*s?a9>bNs?>yy^vNPePE6*SSF@JJaWG?uV>|IF)cr+0J-@y}nm-f=AIo-C#u$~-A) zfDatyNP3k!L!cKmUA>C6Nh)=)atcmqa#1~Jq@sIx*>hpm*VC)@%@|ZO#Y_1E!x-=x z_klZEX6P&A!tT0`dFr^|yhzYDSoOM~8cB2v$QRV^Eo$So;Rv4M;|ea|Op|a0DBeNH zNxJl1aQ82aKNpmo?i?)d5khjon<1;ZPzEJlDTKkX_V06}<8Rr@@ga&6u5uT@`t9A^ z8p<^X;he+|^A zJA;j~02QXA0@S-!VZ%G_Jrvo6X zG#l>8auoZi=Tf{Ku-9&K_^7+CZhmc&oV}4r1DP!7tGB!m=(j7F*-)w~4BTXe>=M2@-m1r#5gUUlUdvCR08V zbI^+i$-Yt)$=D3fWmmGzMqdDLIvnn8@kN1(46_MDH>~LR>q|HJYeM zRAisOdv;PvBJpqxW@<(?FuU;C(e#OT#wqnDJ|Uw*g;bZDMjn zE#uw*)DaZvL4Wt0Pr|$SRsfHI0kyEDWoA)lGtGw;MI6o>r_p07#lc+Ic<~n}s5^4R zgeFe;MBEA99sl4P=hE#-ECR#K%@F>IH|=O4@5`ph79|qb8yb4JDvc~b)V9tH*k?PS(IhXuH9%gT>^{fSLN|IJ3;tNm$J!9+O z4u<7iGeL@0;g7*wQv?M^bxc5IsA^Nca|4Fe9s$IPk1qbHwYC}ZL>;QF-6Ru&_x1)| z%sLG`7Jd;vtXSUYQrlWPi8(fxpz3Cof)CL6pSmgFfe3M5&o?z)>1u7`1zJ{6{(Kzz zSh?d;55*h{%_R!m>fzn5M50bT^n z>0#`XXEq$1t;(mCSBT}{%JH!Yvs6u4egZbN@D?Pq-v6*p`_acc^wUZ_&9hS`7m= z$)~Il#OLM1nz^pW?Gf|D>uLN?+<eE5_9NSZB$Ap6vR)?a%1*lj{cw1?huy2SL4DU9m;{r<9b&xRx6 z#7F)3B;?g@pVgOap#Pz!;}+WnAL_v3TGroXQeH9#SE{9d==47C7_6Qi%jv07(6zN5 zuKN-7%p z-7;44ZHpyL{b--&vM(GvIUSdJ7Ehy$>YG%(!*>Qi`=tvs+owOPQ?c4IPxssBaF;cA zM;+q@!`4gdyh2rZ&gST|fJ4W4(^W|}#362_*;26D6AzJNOf<3Tz}=km=ghwq%9;NSy0bk^!bfJhF&Kzo#s&`wqon!m90Q5R|i`YuGay zET582KQ@&4~p#_zjF zG`0$<@R5|^RTr?y2}E#^gMu3b5?1*VDjhT_QUGO<+I%2H3jF1>++UEFe?1XBm@LAN zXxHgke&LpvOP>(5V5u?hE;~8o5hq{=y8KWf=943TN+9!i-SO z7_+OlVFFP@mBjZZVTCWm=N=@g9p8zeF+rqX7EQeFS5iDt^fg_EmWvWfq*MyS%Q#JU zWp}pE>;;u0j*>!IM%kh%FN0A4OBMtZ*ki~}Rp7<@zR|jMws-G-QF45*dq_h{I!qt!PxQvjK-4VNxWl7F`I2u38;_XDg zEevWprOoc(QLt^!d`D!L-25ZCd6Rdh_9Q3#bl|jYZ)2i%1A7ub3mC^GmgiJX8?4V< zbo2GA3ccql7+qfHz`xEn)>@$yV}H*$!(4#^ttAC-w)=@3EMIMW?ILZL$lR0Dbu)X#qM%(+u!PC|&4;wh$LPAV{w%9o94%MgacAL@zZVHYsd2=UydG}2f}^c@UK+FLXM>TY z<7G+`t|!5G7Zy2wGe4UuJw_2myLq^rgV6OBa4uI)OOiaJsdMFAF_w^1XDMdg{|-rE zRT=UoAT<01z+oW_wX5`G_^PSp#cmsUA@8lg~2ly^)7vhgZG6|

+def average_checkpoints(
+    filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
+    """Average a list of checkpoints.
+    The function is mainly used for averaging DeepSpeed-converted checkpoints,
+    which only include the model state_dict.
+
+    Args:
+      filenames:
+        Filenames of the checkpoints to be averaged. We assume all
+        checkpoints are saved by :func:`save_checkpoint`.
+      device:
+        Move checkpoints to this device before averaging.
+    Returns:
+      Return a dict (i.e., state_dict) which is the average of all
+      model state dicts contained in the checkpoints.
+    """
+    n = len(filenames)
+
+    if "model" in torch.load(filenames[0], map_location=device):
+        avg = torch.load(filenames[0], map_location=device)["model"]
+    else:
+        avg = torch.load(filenames[0], map_location=device)
+
+    # Identify shared parameters. Two parameters are said to be shared
+    # if they have the same data_ptr.
+    uniqued: Dict[int, str] = dict()
+
+    for k, v in avg.items():
+        v_data_ptr = v.data_ptr()
+        if v_data_ptr in uniqued:
+            continue
+        uniqued[v_data_ptr] = k
+
+    uniqued_names = list(uniqued.values())
+
+    for i in range(1, n):
+        if "model" in torch.load(filenames[i], map_location=device):
+            state_dict = torch.load(filenames[i], map_location=device)["model"]
+        else:
+            state_dict = torch.load(filenames[i], map_location=device)
+        for k in uniqued_names:
+            avg[k] += state_dict[k]
+
+    for k in uniqued_names:
+        if avg[k].is_floating_point():
+            avg[k] /= n
+        else:
+            avg[k] //= n
+
+    return avg
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--llm-path-or-name",
+        type=str,
+        default="/workspace/asr/Qwen1.5-0.5B-Chat",
+        help="Path or name of the large language model.",
+    )
+
+    parser.add_argument(
+        "--speech-encoder-path-or-name",
+        type=str,
+        default="whisper-large-v2",
+        help="Path or name of the speech encoder.",
+    )
+
+    parser.add_argument(
+        "--encoder-projector-ds-rate",
+        type=int,
+        default=8,
+        help="Downsample rate for the encoder projector.",
+    )
+
+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
+    parser.add_argument(
+        "--use-lora",
+        type=str2bool,
+        default=True,
+        help="Whether to use the LoRA fine-tuned LLM checkpoint.",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=-1,
+        help="It specifies the checkpoint to use for decoding. "
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=1,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="beam-search",
+        help="""Decoding method.
+ Supported values are: + - beam-search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=1, + help="beam size for beam search decoding", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--remove-whisper-encoder-input-length-restriction", + type=str2bool, + default=True, + help="replace whisper encoder forward method to remove input length restriction", + ) + + parser.add_argument( + "--dataset", + type=str, + default="aishell", + choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"], + help="The dataset to decode", + ) + + add_model_arguments(parser) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + tokenizer: AutoTokenizer, + batch: dict, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: "beam-search" + - value: A list of lists. Each sublist is a list of token IDs. + Args: + params: + It is returned by :func:`get_params`. + model: + The neural model. + batch: + It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. + Returns: + Return a dict, whose key may be "beam-search". + """ + + def preprocess( + messages, + tokenizer: transformers.PreTrainedTokenizer, + max_len: int = 128, + ) -> Dict: + """Preprocesses the data for supervised fine-tuning.""" + texts = [] + TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" + for i, msg in enumerate(messages): + texts.append( + tokenizer.apply_chat_template( + msg, + tokenize=True, + add_generation_prompt=False, + chat_template=TEMPLATE, + padding="longest", + max_length=max_len, + truncation=True, + ) + ) + max_len_texts = max([len(text) for text in texts]) + if tokenizer.padding_side == "right": + texts = [ + text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + for text in texts + ] + else: + texts = [ + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text + for text in texts + ] + + input_ids = torch.tensor(texts, dtype=torch.int) + + attention_mask = input_ids.ne(tokenizer.pad_token_id) + + return input_ids, attention_mask + + dtype = torch.float32 + device = model.llm.device + + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device, dtype=dtype).transpose(1, 2) + if not params.remove_whisper_encoder_input_length_restriction: + T = 3000 + if feature.shape[2] < T: + feature = torch.cat( + [ + feature, + torch.zeros( + feature.shape[0], feature.shape[1], T - feature.shape[2] + ).to(device, dtype=dtype), + ], + 2, + ) + + supervisions = batch["supervisions"] + feature_len = supervisions["num_frames"] + feature_len = feature_len.to(device, dtype=dtype) + + messages = [ + [ + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, + {"role": "assistant", "content": ""}, + ] + ] * len(feature) + + input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128) + + generated_ids = model.decode( + feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device) + ) + hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + return {"beam-search": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + 
model: nn.Module, + tokenizer: AutoTokenizer, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + The dataloader. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "beam-search". + """ + + def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: + """ + Text normalization similar to M2MeT challenge baseline. + See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + if normalize == "none": + return text + elif normalize == "m2met": + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + batch=batch, + tokenizer=tokenizer, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_text = normalize_text_alimeeting(ref_text) + ref_words = ref_text.split() + print(f"ref: {ref_text}") + print(f"hyp: {''.join(hyp_words)}") + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + # we compute CER for aishell dataset. 
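+ # Each (cut_id, ref_words, hyp_words) entry is joined and re-split into
+ # characters, e.g. ["你好", "世界"] -> ["你", "好", "世", "界"], so the
+ # word-level error stats written below effectively measure CER.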
+ results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + setup_logger( + f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + ) + + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + + logging.info(f"device: {device}") + + if params.remove_whisper_encoder_input_length_restriction: + replace_whisper_encoder_forward() + + whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") + speech_encoder = whisper_model.encoder + speech_encoder_dim = whisper_model.dims.n_audio_state + tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) + + if params.use_flash_attn: + attn_implementation = "flash_attention_2" + # torch_dtype=torch.bfloat16 FIX ME + torch_dtype = torch.float16 + tokenizer.padding_side = "left" + + else: + attn_implementation = "eager" + torch_dtype = torch.float16 + tokenizer.padding_side = "right" + + llm = AutoModelForCausalLM.from_pretrained( + params.llm_path_or_name, + attn_implementation=attn_implementation, + torch_dtype=torch_dtype, + ) + if params.use_lora: + lora_config = LoraConfig( + r=64, + lora_alpha=16, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "up_proj", + "gate_proj", + "down_proj", + ], + task_type="CAUSAL_LM", + ) + llm = get_peft_model(llm, lora_config) + llm.print_trainable_parameters() + + special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]} + tokenizer.add_special_tokens(special_tokens_dict) + llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>") + llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>") + llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + + llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids( + DEFAULT_SPEECH_TOKEN + ) + + encoder_projector = EncoderProjector( + speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate + ) + + model = SPEECH_LLM( + speech_encoder, + llm, + encoder_projector, + ) + + if params.avg > 1: + start = params.epoch - params.avg + 1 + assert start >= 1, start + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + assert "model" not in checkpoint + # deepspeed converted checkpoint only contains model state_dict + filenames = [ + 
f"{params.exp_dir}/epoch-{epoch}.pt" + for epoch in range(start, params.epoch + 1) + ] + avg_checkpoint = average_checkpoints(filenames) + model.load_state_dict(avg_checkpoint, strict=False) + + filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save(avg_checkpoint, filename) + else: + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + model.load_state_dict(checkpoint, strict=False) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) + + def remove_long_utt(c: Cut): + # Keep only utterances with duration in 30 seconds + # + if c.duration > 30.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + return True + + if params.dataset == "aishell": + test_sets_cuts = multi_dataset.aishell_test_cuts() + elif params.dataset == "speechio": + test_sets_cuts = multi_dataset.speechio_test_cuts() + elif params.dataset == "wenetspeech_test_meeting": + test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts() + else: + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dls = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + tokenizer=tokenizer, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json new file mode 100644 index 000000000..730937a21 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json @@ -0,0 +1,38 @@ +{ + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 100, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 0.01 + }, + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-4, + "warmup_num_steps": 100 + } + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": 5, + "steps_per_print": 50, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": false +} diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py new file mode 100644 index 000000000..829ef4e2d --- /dev/null +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py @@ -0,0 +1,285 @@ +import torch +from torch import nn +from transformers.trainer_pt_utils import LabelSmoother + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +class EncoderProjector(nn.Module): + """ + The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model. 
+ Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py. + Args: + encoder_dim (:obj:`int`): The dimension of the encoder outputs. + llm_dim (:obj:`int`): The dimension of the language model. + downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use. + """ + + def __init__(self, encoder_dim, llm_dim, downsample_rate=5): + super().__init__() + self.downsample_rate = downsample_rate + self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(llm_dim, llm_dim) + + def forward(self, x): + + batch_size, seq_len, feat_dim = x.size() + num_frames_to_discard = seq_len % self.downsample_rate + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view( + batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate + ) + + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + +class SPEECH_LLM(nn.Module): + """ + The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector. + The encoder is used to extract speech features from the input speech signal. + The encoder projector is used to project the encoder outputs to the same dimension as the language model. + The language model is used to generate the text from the speech features. + Args: + encoder (:obj:`nn.Module`): The encoder module. + llm (:obj:`nn.Module`): The language model module. + encoder_projector (:obj:`nn.Module`): The encoder projector module. + """ + + def __init__( + self, + encoder: nn.Module, + llm: nn.Module, + encoder_projector: nn.Module, + ): + super().__init__() + self.encoder = encoder + self.llm = llm + self.encoder_projector = encoder_projector + + def _merge_input_ids_with_speech_features( + self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None + ): + """ + Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens + with the speech features and padding the input_ids to the maximum length of the speech features. + Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277. + Args: + speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids. + inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids. + input_ids (:obj:`torch.Tensor`): The input ids to merge. + attention_mask (:obj:`torch.Tensor`): The attention mask to merge. + labels (:obj:`torch.Tensor`, `optional`): The labels to merge. + Returns: + :obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids. + """ + num_speechs, speech_len, embed_dim = speech_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum( + input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id) + ) + # 1. Create a mask to know where special speech tokens are + special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id + num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = ( + num_special_speech_tokens.max() * (speech_len - 1) + ) + sequence_length + batch_indices, non_speech_indices = torch.where( + input_ids != self.llm.config.default_speech_token_id + ) + + # 2. 
Compute the positions where text should be written + # Calculate new positions for text tokens in merged speech-text sequence. + # `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens. + # `torch.cumsum` computes how each speech token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = ( + torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1 + ) + nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_speech_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_speech_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, + max_embed_dim, + embed_dim, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + final_attention_mask = torch.zeros( + batch_size, + max_embed_dim, + dtype=attention_mask.dtype, + device=inputs_embeds.device, + ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), + IGNORE_TOKEN_ID, + dtype=input_ids.dtype, + device=input_ids.device, + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_speech_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_speech_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[ + batch_indices, non_speech_indices + ] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[ + batch_indices, non_speech_indices + ] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[ + batch_indices, non_speech_indices + ] + + # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835) + speech_to_overwrite = torch.full( + (batch_size, max_embed_dim), + True, + dtype=torch.bool, + device=inputs_embeds.device, + ) + speech_to_overwrite[batch_indices, text_to_overwrite] = False + speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[ + :, None + ].to(target_device) + + if speech_to_overwrite.sum() != speech_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while" + f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation." + ) + + final_embedding[speech_to_overwrite] = ( + speech_features.contiguous().reshape(-1, embed_dim).to(target_device) + ) + final_attention_mask |= speech_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_( + (final_attention_mask == 0), 1 + ) + + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. 
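+ # Pad positions in `input_ids` are mapped through `new_token_positions` to
+ # their slots in the merged sequence; their embedding rows are zeroed below
+ # so padded slots carry no information into the LLM.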
+ batch_indices, pad_indices = torch.where( + input_ids == self.llm.config.pad_token_id + ) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids + + def forward( + self, + fbank: torch.Tensor = None, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor = None, + labels: torch.LongTensor = None, + ): + encoder_outs = self.encoder(fbank) + + speech_features = self.encoder_projector(encoder_outs) + + inputs_embeds = self.llm.get_input_embeddings()(input_ids) + + ( + inputs_embeds, + attention_mask, + labels, + _, + ) = self._merge_input_ids_with_speech_features( + speech_features, inputs_embeds, input_ids, attention_mask, labels + ) + + model_outputs = self.llm( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels + ) + + with torch.no_grad(): + preds = torch.argmax(model_outputs.logits, -1) + acc = compute_accuracy( + preds.detach()[:, :-1], + labels.detach()[:, 1:], + ignore_label=IGNORE_TOKEN_ID, + ) + return model_outputs, acc + + def decode( + self, + fbank: torch.Tensor = None, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor = None, + **kwargs, + ): + + encoder_outs = self.encoder(fbank) + speech_features = self.encoder_projector(encoder_outs) + speech_features = speech_features.to(torch.float16) + inputs_embeds = self.llm.get_input_embeddings()(input_ids) + ( + inputs_embeds, + attention_mask, + _, + position_ids, + ) = self._merge_input_ids_with_speech_features( + speech_features, inputs_embeds, input_ids, attention_mask + ) + generated_ids = self.llm.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=kwargs.get("max_new_tokens", 200), + num_beams=kwargs.get("num_beams", 1), + do_sample=kwargs.get("do_sample", False), + min_length=kwargs.get("min_length", 1), + top_p=kwargs.get("top_p", 1.0), + repetition_penalty=kwargs.get("repetition_penalty", 1.0), + length_penalty=kwargs.get("length_penalty", 1.0), + temperature=kwargs.get("temperature", 1.0), + bos_token_id=self.llm.config.bos_token_id, + eos_token_id=self.llm.config.eos_token_id, + pad_token_id=self.llm.config.pad_token_id, + ) + + return generated_ids + + +def compute_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py + Args: + pad_outputs (LongTensor): Prediction tensors (B, Lmax). + pad_targets (LongTensor): Target label tensors (B, Lmax). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_outputs.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = torch.sum(mask) + return numerator.float() / denominator.float() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py new file mode 100644 index 000000000..eae967500 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py @@ -0,0 +1,338 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import glob +import logging +import re +from pathlib import Path +from typing import Dict, List + +import lhotse +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, fbank_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - aishell_cuts_train.jsonl.gz + - aishell2_cuts_train.jsonl.gz + - aishell4_cuts_train_L.jsonl.gz + - aishell4_cuts_train_M.jsonl.gz + - aishell4_cuts_train_S.jsonl.gz + - alimeeting-far_cuts_train.jsonl.gz + - magicdata_cuts_train.jsonl.gz + - primewords_cuts_train.jsonl.gz + - stcmds_cuts_train.jsonl.gz + - thchs_30_cuts_train.jsonl.gz + - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz + - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz + - wenetspeech/cuts_L_fixed.jsonl.gz + """ + self.fbank_dir = Path(fbank_dir) + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + # THCHS-30 + logging.info("Loading THCHS-30 in lazy mode") + thchs_30_cuts = load_manifest_lazy( + self.fbank_dir / "thchs_30_cuts_train.jsonl.gz" + ) + + # AISHELL-1 + logging.info("Loading Aishell-1 in lazy mode") + aishell_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_train.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 in lazy mode") + aishell_2_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_train.jsonl.gz" + ) + + # AISHELL-4 + logging.info("Loading Aishell-4 in lazy mode") + aishell_4_L_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz" + ) + aishell_4_M_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz" + ) + aishell_4_S_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz" + ) + + # ST-CMDS + logging.info("Loading ST-CMDS in lazy mode") + stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz") + + # Primewords + logging.info("Loading Primewords in lazy mode") + primewords_cuts = load_manifest_lazy( + self.fbank_dir / "primewords_cuts_train.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData in lazy mode") + magicdata_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_train.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting in lazy mode") + alimeeting_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech in lazy mode") + wenetspeech_L_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech in lazy mode") + kespeech_1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz" + ) + kespeech_2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz" + ) + + return CutSet.mux( + thchs_30_cuts, + aishell_cuts, + aishell_2_cuts, + aishell_4_L_cuts, + aishell_4_M_cuts, + aishell_4_S_cuts, + alimeeting_cuts, + stcmds_cuts, + primewords_cuts, + magicdata_cuts, + wenetspeech_L_cuts, + kespeech_1_cuts, + kespeech_2_cuts, + weights=[ + 
len(thchs_30_cuts), + len(aishell_cuts), + len(aishell_2_cuts), + len(aishell_4_L_cuts), + len(aishell_4_M_cuts), + len(aishell_4_S_cuts), + len(alimeeting_cuts), + len(stcmds_cuts), + len(primewords_cuts), + len(magicdata_cuts), + len(wenetspeech_L_cuts), + len(kespeech_1_cuts), + len(kespeech_2_cuts), + ], + ) + + def dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + + # WeNetSpeech + logging.info("Loading WeNetSpeech DEV set in lazy mode") + wenetspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz" + ) + + return wenetspeech_dev_cuts + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + # AISHELL + logging.info("Loading Aishell set in lazy mode") + aishell_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_test.jsonl.gz" + ) + aishell_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_dev.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 set in lazy mode") + aishell2_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_test.jsonl.gz" + ) + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # AISHELL-4 + logging.info("Loading Aishell-4 TEST set in lazy mode") + aishell4_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_test.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting set in lazy mode") + alimeeting_test_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz" + ) + alimeeting_eval_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData set in lazy mode") + magicdata_test_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_test.jsonl.gz" + ) + magicdata_dev_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech set in lazy mode") + kespeech_test_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz" + ) + kespeech_dev_phase1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" + ) + kespeech_dev_phase2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech set in lazy mode") + wenetspeech_test_meeting_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz" + ) + wenetspeech_test_net_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz" + ) + wenetspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz" + ) + + return { + "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, + "aishell_test": aishell_test_cuts, + "aishell_dev": aishell_dev_cuts, + "ali-meeting_test": alimeeting_test_cuts, + "ali-meeting_eval": alimeeting_eval_cuts, + "aishell-4_test": aishell4_test_cuts, + "aishell-2_test": aishell2_test_cuts, + "aishell-2_dev": aishell2_dev_cuts, + "magicdata_test": magicdata_test_cuts, + "magicdata_dev": magicdata_dev_cuts, + "kespeech-asr_test": kespeech_test_cuts, + "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts, + "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts, + "wenetspeech-net_test": wenetspeech_test_net_cuts, + "wenetspeech_dev": wenetspeech_dev_cuts, + } + + def aishell_train_cuts(self) -> CutSet: + logging.info("About to get 
multidataset train cuts") + logging.info("Loading Aishell-1 in lazy mode") + aishell_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_train.jsonl.gz" + ) + + return aishell_cuts + + def aishell_dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + logging.info("Loading Aishell set in lazy mode") + aishell_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_dev.jsonl.gz" + ) + + return aishell_dev_cuts + + def aishell_test_cuts(self) -> CutSet: + logging.info("About to get multidataset test cuts") + logging.info("Loading Aishell set in lazy mode") + aishell_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_test.jsonl.gz" + ) + + return { + "aishell_test": aishell_test_cuts, + } + + def aishell2_train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + logging.info("Loading Aishell-2 in lazy mode") + aishell_2_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_train.jsonl.gz" + ) + + return aishell_2_cuts + + def aishell2_dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + logging.info("Loading Aishell-2 set in lazy mode") + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + return aishell2_dev_cuts + + def aishell2_test_cuts(self) -> CutSet: + logging.info("About to get multidataset test cuts") + logging.info("Loading Aishell-2 set in lazy mode") + aishell2_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_test.jsonl.gz" + ) + + return { + "aishell2_test": aishell2_test_cuts, + } + + def wenetspeech_test_meeting_cuts(self) -> CutSet: + logging.info("About to get multidataset test cuts") + logging.info("Loading WeNetSpeech set in lazy mode") + wenetspeech_test_meeting_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz" + ) + + return { + "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, + } + + def speechio_test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + start_index = 0 + end_index = 26 + dataset_parts = [] + for i in range(start_index, end_index + 1): + idx = f"{i}".zfill(2) + dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") + + prefix = "speechio" + suffix = "jsonl.gz" + + results_dict = {} + for partition in dataset_parts: + path = f"{prefix}_cuts_{partition}.{suffix}" + + logging.info(f"Loading {path} set in lazy mode") + test_cuts = load_manifest_lazy(self.fbank_dir / path) + results_dict[partition] = test_cuts + + return results_dict diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt new file mode 100644 index 000000000..a07c7b157 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt @@ -0,0 +1,11 @@ +k2 +kaldialign +git+https://github.com/lhotse-speech/lhotse +sentencepiece +pypinyin +tensorboard +librosa +deepspeed +transformers>=4.37.0 +flash-attn +peft diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py new file mode 100755 index 000000000..5f224c984 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -0,0 +1,872 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +# fine-tuning with whisper and Qwen2 +pip install huggingface_hub['cli'] +mkdir -p models/whisper models/qwen + +# For aishell fine-tuned whisper model +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt +# For multi-hans fine-tuned whisper model +# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt + +# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct +huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct + +torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ + --max-duration 200 \ + --exp-dir ./whisper_llm_zh/exp_test \ + --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ + --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \ + --manifest-dir data/fbank \ + --deepspeed \ + --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ + --use-flash-attn True \ + --use-lora True --unfreeze-llm True +""" + +import argparse +import copy +import logging +import os +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import deepspeed +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import transformers +import whisper +from asr_datamodule import AsrDataModule +from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict +from label_smoothing import LabelSmoothingLoss +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector +from multi_dataset import MultiDataset +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoModelForCausalLM, AutoTokenizer +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward + +from icefall import diagnostics +from icefall.dist import get_rank, get_world_size +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +DEFAULT_SPEECH_TOKEN = "" + + +def set_batch_count(model: nn.Module, batch_count: float) -> None: + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--llm-path-or-name", + type=str, + default="/workspace/asr/Qwen1.5-0.5B-Chat", + help="Path or name of the large language model.", + ) + + parser.add_argument( + "--speech-encoder-path-or-name", + type=str, + default="whisper-large-v2", + help="Path or name of the speech encoder.", + ) + + parser.add_argument( + "--encoder-projector-ds-rate", + type=int, + default=8, + help="Downsample rate 
for the encoder projector.", + ) + parser.add_argument( + "--use-flash-attn", + type=str2bool, + default=True, + help="Whether to use flash attention.", + ) + + parser.add_argument( + "--use-lora", + type=str2bool, + default=False, + help="Whether to use lora to fine-tune llm.", + ) + + parser.add_argument( + "--unfreeze-llm", + type=str2bool, + default=False, + help="Whether to unfreeze llm during training.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper_qwen/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--pretrained-model-path", + type=str, + default=None, + help="""The path to the pretrained model if it is not None. Training will + start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt + """, + ) + + parser.add_argument( + "--sampler-state-dict-path", + type=str, + default=None, + help="""The path to the sampler state dict if it is not None. Training will start from this sampler state dict. + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-aishell", + type=str2bool, + default=True, + help="Whether to only use aishell1 dataset for training.", + ) + + parser = deepspeed.add_config_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - frame_shift_ms: The frame shift in milliseconds. + - allowed_excess_duration_ratio: The allowed excess duration ratio. + - best_train_loss: The best training loss so far. + - best_valid_loss: The best validation loss so far. + - best_train_epoch: The epoch where the best training loss is achieved. + - best_valid_epoch: The epoch where the best validation loss is achieved. + - batch_idx_train: The batch index of the current batch. + - log_interval: Log training stats every `log_interval` batches. + - reset_interval: Reset the stats every `reset_interval` batches. + - valid_interval: Run validation every `valid_interval` batches. + - env_info: The environment information. 
+ """ + params = AttributeDict( + { + "allowed_excess_duration_ratio": 0.1, + "subsampling_factor": 2, + "frame_shift_ms": 10, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 5000, + "env_info": get_env_info(), + } + ) + + return params + + +def compute_loss( + params: AttributeDict, + tokenizer: AutoTokenizer, + model: nn.Module, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute the loss for the given batch. + Args: + params: + It is returned by :func:`get_params`. + tokenizer: + The tokenizer used to encode the text. + model: + The model for training. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + Whether it is training. + Returns: + Return a tuple of two elements. The first element is the loss tensor. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + + def preprocess( + messages, + tokenizer: transformers.PreTrainedTokenizer, + max_len: int, + ) -> Dict: + """Preprocesses the data for supervised fine-tuning.""" + texts = [] + TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" + for i, msg in enumerate(messages): + texts.append( + tokenizer.apply_chat_template( + msg, + tokenize=True, + chat_template=TEMPLATE, + add_generation_prompt=False, + padding="longest", # FIX me change padding to longest + max_length=max_len, + truncation=True, + ) + ) + # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id + max_len_texts = max([len(text) for text in texts]) + if tokenizer.padding_side == "right": + texts = [ + text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + for text in texts + ] + else: + texts = [ + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text + for text in texts + ] + input_ids = torch.tensor(texts, dtype=torch.int) + # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] + target_ids = input_ids.clone() + target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID + # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID + # first get the indices of the tokens + mask_prompt = True + if mask_prompt: + mask_indices = torch.where( + input_ids == tokenizer.convert_tokens_to_ids("assistant") + ) + for i in range(mask_indices[0].size(0)): + row = mask_indices[0][i] + col = mask_indices[1][i] + # + 2 to skip: 'assistant', '\n' + target_ids[row, : col + 2] = IGNORE_TOKEN_ID + + attention_mask = input_ids.ne(tokenizer.pad_token_id) + + return input_ids, attention_mask, target_ids + + def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: + """ + Text normalization similar to M2MeT challenge baseline. 
+ See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + if normalize == "none": + return text + elif normalize == "m2met": + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = next(model.parameters()).device + feature = batch["inputs"] + + assert feature.ndim == 3 + feature = feature.to(device) + feature = feature.transpose(1, 2) # (N, C, T) + + batch_idx_train = params.batch_idx_train + supervisions = batch["supervisions"] + texts = batch["supervisions"]["text"] + # remove spaces in texts + texts = [normalize_text_alimeeting(text) for text in texts] + + messages = [] + for i, text in enumerate(texts): + message = [ + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, + {"role": "assistant", "content": text}, + ] + messages.append(message) + + input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128) + + target_ids = target_ids.type(torch.LongTensor) + input_ids = input_ids.type(torch.LongTensor) + + with torch.set_grad_enabled(is_training): + model_outputs, acc = model( + fbank=feature, + input_ids=input_ids.to(device), + attention_mask=attention_mask.to(device), + labels=target_ids.to(device), + ) + loss = model_outputs.loss + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + feature_lens = supervisions["num_frames"] + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
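+ # MetricsTracker sums these values over batches and later normalizes them by
+ # the total number of frames; storing `acc * frames` cancels that
+ # normalization and yields a frame-weighted average accuracy.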
+ info["loss"] = loss.detach().cpu().item() + info["acc"] = ( + acc * info["frames"] + ) # WAR: to avoid normalization by the number of frames + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + tokenizer: AutoTokenizer, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.encoder_projector.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + tokenizer=tokenizer, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + if batch_idx != 0: + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", + client_state={}, + exclude_frozen_parameters=True, + ) + + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", + tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", + exclude_frozen_parameters=True, + ) + # save sampler state dict into checkpoint + sampler_state_dict = train_dl.sampler.state_dict() + torch.save( + sampler_state_dict, + f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt", + ) + os.system( + f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" + ) + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + # deepspeed's backward() is different from torch's backward() + # in that it does not accept a loss tensor as input. + # It computes the loss internally. + model.backward(loss) + model.step() + + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if batch_idx % params.log_interval == 0: + try: + cur_lr = scheduler.get_last_lr()[0] + except: # noqa + cur_lr = 0.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info(params) + + logging.info("About to create model") + + replace_whisper_encoder_forward() + whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") + speech_encoder = whisper_model.encoder + speech_encoder_dim = whisper_model.dims.n_audio_state + for name, param in speech_encoder.named_parameters(): + param.requires_grad = False + speech_encoder.eval() + + tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) + if params.use_flash_attn: + attn_implementation = "flash_attention_2" + # torch_dtype=torch.bfloat16 FIX ME + torch_dtype = torch.float16 + tokenizer.padding_side = "left" + + else: + attn_implementation = "eager" + torch_dtype = torch.float16 + tokenizer.padding_side = "right" + + llm = AutoModelForCausalLM.from_pretrained( + params.llm_path_or_name, + attn_implementation=attn_implementation, + torch_dtype=torch_dtype, + ) + + if not params.unfreeze_llm: + for name, param in llm.named_parameters(): + param.requires_grad = False + llm.eval() + else: + if params.use_lora: + lora_config = LoraConfig( + r=64, + lora_alpha=16, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "up_proj", + "gate_proj", + "down_proj", + ], + lora_dropout=0.05, + task_type="CAUSAL_LM", + ) + llm = get_peft_model(llm, lora_config) + llm.print_trainable_parameters() + + special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]} + tokenizer.add_special_tokens(special_tokens_dict) + llm.config.pad_token_id = tokenizer.pad_token_id + llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids( + DEFAULT_SPEECH_TOKEN + ) + + encoder_projector = EncoderProjector( + speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate + ) + + model = SPEECH_LLM( + speech_encoder, + llm, + encoder_projector, + ) + + if params.pretrained_model_path: + checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") + missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + logging.info("Trainable parameters (excluding model.eval modules):") + for name, param in model.named_parameters(): + if param.requires_grad: + logging.info(f"{name}: {param.shape}") + + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + else: + device = torch.device("cpu") + logging.info(f"Device: {device}") + model.to(device) + + assert params.deepspeed and world_size > 1 + logging.info("Using DeepSpeed") + model, optimizer, _, scheduler = deepspeed.initialize( + args=params, model=model, model_parameters=model.parameters() + ) + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + return True + + if params.use_aishell: + train_cuts = multi_dataset.aishell_train_cuts() + else: + train_cuts = multi_dataset.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + sampler_state_dict = None + if params.sampler_state_dict_path: + sampler_state_dict = torch.load(params.sampler_state_dict_path) + sampler_state_dict["max_duration"] = params.max_duration + # TODO: load sampler state dict + train_dl = data_module.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + if params.use_aishell: + valid_cuts = multi_dataset.aishell_dev_cuts() + else: + valid_cuts = multi_dataset.dev_cuts() + valid_dl = data_module.valid_dataloaders(valid_cuts) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + logging.info(f"start training from epoch {params.start_epoch}") + for epoch in range(params.start_epoch, params.num_epochs + 1): + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + tokenizer=tokenizer, + model=model, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}", + client_state={}, + exclude_frozen_parameters=True, + ) + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", + tag=f"epoch-{params.cur_epoch}", + exclude_frozen_parameters=True, + ) + # save sampler state dict into checkpoint + sampler_state_dict = train_dl.sampler.state_dict() + torch.save( + sampler_state_dict, + f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt", + ) + + os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") + + logging.info("Done!") + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = get_world_size() + rank = get_rank() + + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + run(rank=rank, world_size=world_size, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/whisper_encoder_forward_monkey_patch.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/whisper_encoder_forward_monkey_patch.py new file mode 120000 index 000000000..2a7808921 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/whisper_encoder_forward_monkey_patch.py @@ -0,0 +1 @@ +../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py \ No newline at end of file From c13c7aa30b2845c2a6abfadaee2ebf98804d889b Mon Sep 17 00:00:00 2001 From: Seung Hyun Lee Date: Sun, 16 Jun 2024 17:20:44 +0900 Subject: [PATCH 176/216] Add Streaming Zipformer-Transducer recipe for KsponSpeech (#1651) --- egs/ksponspeech/ASR/README.md | 32 + egs/ksponspeech/ASR/RESULTS.md | 70 + egs/ksponspeech/ASR/local/__init__.py | 0 .../ASR/local/compute_fbank_ksponspeech.py | 185 +++ .../ASR/local/compute_fbank_musan.py | 158 +++ egs/ksponspeech/ASR/local/filter_cuts.py | 157 +++ egs/ksponspeech/ASR/local/train_bpe_model.py | 115 ++ .../ASR/local/validate_manifest.py | 101 ++ .../README.md | 1 + .../asr_datamodule.py | 415 ++++++ .../beam_search.py | 1 + .../decode.py | 993 +++++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export-onnx.py | 1 + .../export.py | 1 + .../joiner.py | 1 + .../model.py | 1 + .../onnx_check.py | 1 + .../onnx_model_wrapper.py | 1 + .../onnx_pretrained.py | 1 + .../optim.py | 1 + .../pretrained.py | 1 + .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 619 ++++++++ .../test_model.py | 187 +++ .../train.py | 1243 +++++++++++++++++ .../zipformer.py | 1 + egs/ksponspeech/ASR/shared | 1 + 32 files changed, 4294 insertions(+) create mode 100644 egs/ksponspeech/ASR/README.md create mode 100644 egs/ksponspeech/ASR/RESULTS.md create mode 100644 egs/ksponspeech/ASR/local/__init__.py create mode 100755 egs/ksponspeech/ASR/local/compute_fbank_ksponspeech.py create mode 100755 egs/ksponspeech/ASR/local/compute_fbank_musan.py create mode 100644 egs/ksponspeech/ASR/local/filter_cuts.py create mode 100755 egs/ksponspeech/ASR/local/train_bpe_model.py create mode 100755 egs/ksponspeech/ASR/local/validate_manifest.py create mode 100644 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/README.md create mode 100644 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100755 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 120000 
egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
 create mode 100755 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
 create mode 100755 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/test_model.py
 create mode 100755 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py
 create mode 120000 egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
 create mode 120000 egs/ksponspeech/ASR/shared
diff --git a/egs/ksponspeech/ASR/README.md b/egs/ksponspeech/ASR/README.md
new file mode 100644
index 000000000..2b02b9cca
--- /dev/null
+++ b/egs/ksponspeech/ASR/README.md
@@ -0,0 +1,32 @@
+# Introduction
+KsponSpeech is a large-scale corpus of spontaneous Korean speech.
+It contains 969 hours of open-domain dialog utterances,
+spoken by about 2,000 native Korean speakers in a clean environment.
+
+All data were collected by recording two people conversing freely
+on a variety of topics and then manually transcribing the utterances.
+
+Each utterance carries a dual transcription, orthographic and pronunciation-based,
+plus disfluency tags for spontaneous-speech phenomena such as filler words, repeated words, and word fragments.
+
+The original audio is distributed as pcm files.
+During preprocessing, each file is converted to flac and saved as a new file.
+
+KsponSpeech is publicly available on AI Hub, an open data portal of the Korean government.
+The dataset must be downloaded manually.
+
+For more details, please visit:
+
+ - Dataset: https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=123
+ - Paper: https://www.mdpi.com/2076-3417/10/19/6936
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Transducers
+There are various folders whose names contain `transducer` in this recipe. The following table lists the differences among them.
+
+| | Encoder | Decoder | Comment |
+| ---------------------------------------- | -------------------- | ------------------ | ------------------------------------------------- |
+| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 |
+
+The stateless decoder is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/); we place an additional Conv1d layer right after the input embedding layer, as sketched below.
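+
+The following is a minimal, illustrative sketch of that idea. It is not the recipe's actual `decoder.py` (which differs in details such as Conv1d grouping and handling of the blank/start-of-sequence context); the class and argument names here are only for illustration:
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class StatelessDecoder(nn.Module):
+    """Stateless prediction network: a token embedding followed by a causal
+    Conv1d over the last `context_size` tokens, instead of a recurrent state."""
+
+    def __init__(self, vocab_size: int, decoder_dim: int, context_size: int = 2):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, decoder_dim)
+        self.context_size = context_size
+        # Mixes the embeddings of the previous `context_size` tokens.
+        self.conv = nn.Conv1d(decoder_dim, decoder_dim, kernel_size=context_size, bias=False)
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, U) token IDs  ->  output: (N, U, decoder_dim)
+        emb = self.embedding(y).permute(0, 2, 1)      # (N, C, U)
+        emb = F.pad(emb, (self.context_size - 1, 0))  # causal left-padding
+        out = self.conv(emb).permute(0, 2, 1)         # (N, U, C)
+        return F.relu(out)
+
+
+if __name__ == "__main__":
+    decoder = StatelessDecoder(vocab_size=500, decoder_dim=512)
+    tokens = torch.randint(0, 500, (4, 7))  # a batch of 4 label sequences
+    print(decoder(tokens).shape)            # torch.Size([4, 7, 512])
+```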
\ No newline at end of file
diff --git a/egs/ksponspeech/ASR/RESULTS.md b/egs/ksponspeech/ASR/RESULTS.md
new file mode 100644
index 000000000..66edf8e66
--- /dev/null
+++ b/egs/ksponspeech/ASR/RESULTS.md
@@ -0,0 +1,70 @@
+## Results
+
+### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
+
+#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
+
+Number of model parameters: 79,022,891, i.e., 79.02 M
+
+##### Training on KsponSpeech (with MUSAN)
+
+Model: [johnBamma/icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12](https://huggingface.co/johnBamma/icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12)
+
+The CERs are:
+
+| decoding method      | chunk size | eval_clean | eval_other | comment             | decoding mode        |
+|----------------------|------------|------------|------------|---------------------|----------------------|
+| greedy search        | 320ms      | 10.21      | 11.07      | --epoch 30 --avg 9  | simulated streaming  |
+| greedy search        | 320ms      | 10.22      | 11.07      | --epoch 30 --avg 9  | chunk-wise           |
+| fast beam search     | 320ms      | 10.21      | 11.04      | --epoch 30 --avg 9  | simulated streaming  |
+| fast beam search     | 320ms      | 10.25      | 11.08      | --epoch 30 --avg 9  | chunk-wise           |
+| modified beam search | 320ms      | 10.13      | 10.88      | --epoch 30 --avg 9  | simulated streaming  |
+| modified beam search | 320ms      | 10.1       | 10.93      | --epoch 30 --avg 9  | chunk-wise           |
+| greedy search        | 640ms      | 9.94       | 10.82      | --epoch 30 --avg 9  | simulated streaming  |
+| greedy search        | 640ms      | 10.04      | 10.85      | --epoch 30 --avg 9  | chunk-wise           |
+| fast beam search     | 640ms      | 10.01      | 10.81      | --epoch 30 --avg 9  | simulated streaming  |
+| fast beam search     | 640ms      | 10.04      | 10.7       | --epoch 30 --avg 9  | chunk-wise           |
+| modified beam search | 640ms      | 9.91       | 10.72      | --epoch 30 --avg 9  | simulated streaming  |
+| modified beam search | 640ms      | 9.92       | 10.72      | --epoch 30 --avg 9  | chunk-wise           |
+
+Note: `simulated streaming` means the full utterance is fed at once during decoding with `decode.py`,
+while `chunk-wise` means a fixed number of frames is fed at a time with `streaming_decode.py`.
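+
+(Assuming the default 10 ms fbank frame shift, `--decode-chunk-len 32` corresponds to 32 × 10 ms = 320 ms of audio per chunk, and `--decode-chunk-len 64` would correspond to 640 ms; this is how the chunk sizes in the table map onto the flag used in the commands below.)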
+ +The training command is: + +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 750 \ + --enable-musan True +``` + +The simulated streaming decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search fast_beam_search modified_beam_search; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method $m +done +``` + +The streaming chunk-size decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding-method $m \ + --decode-chunk-len 32 \ + --num-decode-streams 2000 +done +``` \ No newline at end of file diff --git a/egs/ksponspeech/ASR/local/__init__.py b/egs/ksponspeech/ASR/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ksponspeech/ASR/local/compute_fbank_ksponspeech.py b/egs/ksponspeech/ASR/local/compute_fbank_ksponspeech.py new file mode 100755 index 000000000..b186c2296 --- /dev/null +++ b/egs/ksponspeech/ASR/local/compute_fbank_ksponspeech.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. 
If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + parser.add_argument( + "--data-dir", + type=str, + default="data", + help="""Path of data directory""", + ) + + return parser.parse_args() + + +def compute_fbank_speechtools( + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = False, + data_dir: Optional[str] = "data", +): + src_dir = Path(data_dir) / "manifests" + output_dir = Path(data_dir) / "fbank" + num_jobs = min(4, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ( + "train", + "dev", + "eval_clean", + "eval_other", + ) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = "ksponspeech" + suffix = "jsonl.gz" + logging.info(f"Read manifests...") + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + if torch.cuda.is_available(): + # Use cuda for fbank compute + device = "cuda" + else: + device = "cpu" + logging.info(f"Device: {device}") + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, device=device)) + + with get_executor() as ex: # Initialize the executor only once. + logging.info(f"Executor: {ex}") + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + + # Filter duration + cut_set = cut_set.filter( + lambda x: x.duration > 1 and x.sampling_rate == 16000 + ) + + if "train" in partition: + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + logging.info(f"Compute & Store features...") + if device == "cuda": + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + else: + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_speechtools( + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + data_dir=args.data_dir, + ) diff --git a/egs/ksponspeech/ASR/local/compute_fbank_musan.py b/egs/ksponspeech/ASR/local/compute_fbank_musan.py new file mode 100755 index 000000000..c0bdacfe5 --- /dev/null +++ b/egs/ksponspeech/ASR/local/compute_fbank_musan.py @@ -0,0 +1,158 @@ +#!/usr/bin/env 
python3 +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the musan dataset. +It looks for manifests in the directory `src_dir` (default is data/manifests). + +The generated fbank features are saved in data/fbank. +""" +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + MonoCut, + WhisperFbank, + WhisperFbankConfig, + combine, +) +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def is_cut_long(c: MonoCut) -> bool: + return c.duration > 5 + + +def compute_fbank_musan( + src_dir: str = "data/manifests", + num_mel_bins: int = 80, + whisper_fbank: bool = False, + output_dir: str = "data/fbank", +): + src_dir = Path(src_dir) + output_dir = Path(output_dir) + num_jobs = min(15, os.cpu_count()) + + dataset_parts = ( + "music", + "speech", + "noise", + ) + prefix = "musan" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + musan_cuts_path = output_dir / "musan_cuts.jsonl.gz" + + if musan_cuts_path.is_file(): + logging.info(f"{musan_cuts_path} already exists - skipping") + return + + logging.info("Extracting features for Musan") + + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + # create chunks of Musan with duration 5 - 10 seconds + musan_cuts = ( + CutSet.from_manifests( + recordings=combine(part["recordings"] for part in manifests.values()) + ) + .cut_into_windows(10.0) + .filter(is_cut_long) + .compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/musan_feats", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + ) + musan_cuts.to_file(musan_cuts_path) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--src-dir", + type=str, + default="data/manifests", + help="Source manifests directory.", + ) + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. 
Default: False.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. Default: data/fbank.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + compute_fbank_musan( + src_dir=args.src_dir, + num_mel_bins=args.num_mel_bins, + whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, + ) diff --git a/egs/ksponspeech/ASR/local/filter_cuts.py b/egs/ksponspeech/ASR/local/filter_cuts.py new file mode 100644 index 000000000..f081da5df --- /dev/null +++ b/egs/ksponspeech/ASR/local/filter_cuts.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. + +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_5000/bpe.model \ + --in-cuts data/fbank/speechtools_cuts_test.jsonl.gz \ + --out-cuts data/fbank-filtered/speechtools_cuts_test.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. + ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ksponspeech/ASR/local/train_bpe_model.py b/egs/ksponspeech/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..5979d5b98 --- /dev/null +++ b/egs/ksponspeech/ASR/local/train_bpe_model.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path +from typing import Dict + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def generate_tokens(lang_dir: Path): + """ + Generate the tokens.txt from a bpe model. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f: + for sym, i in token2id.items(): + f.write(f"{sym} {i}\n") + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + generate_tokens(lang_dir) + + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/local/validate_manifest.py b/egs/ksponspeech/ASR/local/validate_manifest.py new file mode 100755 index 000000000..98f273419 --- /dev/null +++ b/egs/ksponspeech/ASR/local/validate_manifest.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut +- Supervision time bounds are within cut time bounds + +We will add more checks later if needed. 
+ +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/speechtools_cuts_train.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut +from lhotse.dataset.speech_recognition import validate_for_asr + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def validate_one_supervision_per_cut(c: Cut): + if len(c.supervisions) != 1: + raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") + + +def validate_supervision_and_cut_time_bounds(c: Cut): + tol = 2e-3 # same tolerance as in 'validate_for_asr()' + s = c.supervisions[0] + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + if s.start < -tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} must not be negative." + ) + if s.start > tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`." + ) + if c.start + s.end > c.end + tol: + raise ValueError( + f"{c.id}: Supervision end time {c.start+s.end} is larger " + f"than cut end time {c.end}" + ) + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet) + + for c in cut_set: + validate_one_supervision_per_cut(c) + validate_supervision_and_cut_time_bounds(c) + + # Validation from K2 training + # - checks supervision start is 0 + # - checks supervision.duration is not longer than cut.duration + # - there is tolerance 2ms + validate_for_asr(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/README.md new file mode 100644 index 000000000..644bf9564 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/README.md @@ -0,0 +1 @@ +This recipe implements Streaming Zipformer-Transducer model. \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 100644 index 000000000..9a5b3fc52 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1,415 @@ +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class KsponSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader. + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. 
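+            # Use 10 frame masks by default; if the installed Lhotse still has the
+            # old default (num_frame_masks=1), fall back to 2 masks instead.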
+ num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
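+        # _SeedWorkers then seeds each dataloader worker with (seed + worker_id),
+        # so workers draw different but reproducible random augmentations.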
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts.") + return load_manifest_lazy( + self.args.manifest_dir / "ksponspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ksponspeech_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def eval_clean_cuts(self) -> CutSet: + logging.info("About to get eval_clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ksponspeech_cuts_eval_clean.jsonl.gz" + ) + + @lru_cache() + def eval_other_cuts(self) -> CutSet: + logging.info("About to get eval_other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ksponspeech_cuts_eval_other.jsonl.gz" + ) diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..d7349b0a3 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..0f3f1c1ab --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,993 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import KsponSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import LmScorer, NgramLm +from icefall.checkpoint import ( + 
average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. 
+ ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += 30 + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, 30), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + 
hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"beam_size_{params.beam_size}_{key}"] = hyps + return ans + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + ngram_lm: + A n-gram LM to be used for LODR. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
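To see how the keys of the returned dict encode the decoding configuration, here is a minimal standalone sketch of the key construction above; the parameter values are hypothetical and would normally come from the command-line arguments:

# Sketch only: hypothetical values, mirroring the f-strings used above.
beam, max_contexts, max_states = 20.0, 8, 64
num_paths, nbest_scale = 200, 0.5
key = f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}"
key += f"_num_paths_{num_paths}_nbest_scale_{nbest_scale}"
print(key)  # beam_20.0_max_contexts_8_max_states_64_num_paths_200_nbest_scale_0.5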
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out CERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + cer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + KsponSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += 
f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} 
(excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
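The epoch range handled by average_checkpoints_with_averaged_model above is easy to misread, so here is a small worked example; the --epoch/--avg values and the exp path are hypothetical:

# Hypothetical command-line values: --epoch 30 --avg 10 --use-averaged-model True
epoch, avg = 30, 10
start = epoch - avg                       # 20; the assert above requires start >= 1
filename_start = f"exp/epoch-{start}.pt"  # the *excluded* left endpoint
filename_end = f"exp/epoch-{epoch}.pt"
# The averaged model then covers epochs 21..30, i.e. exactly `avg` checkpoints.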
+ args.return_cuts = True + ksponspeech = KsponSpeechAsrDataModule(args) + + eval_clean_cuts = ksponspeech.eval_clean_cuts() + eval_other_cuts = ksponspeech.eval_other_cuts() + + eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts) + eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts) + + test_sets = ["eval_clean", "eval_other"] + test_dl = [eval_clean_dl, eval_other_dl] + import time + + for test_set, test_dl in zip(test_sets, test_dl): + start = time.time() + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + logging.info(f"Elasped time for {test_set}: {time.time() - start}") + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..1ce277aa6 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..cb673b3eb --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 120000 index 000000000..57a0cd0a0 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 120000 index 000000000..2acafdc61 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..482ebcfef --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py 
b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..16c2bf28d --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 120000 index 000000000..28bf7bb82 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 120000 index 000000000..c8548d459 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 120000 index 000000000..ae4d9bb04 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..522bbaff9 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 120000 index 000000000..9510b8fde --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..a7ef73bcb --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..566c317ff --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 
000000000..2adf271c1 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..d777b769c --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,619 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import KsponSpeechAsrDataModule +from decode_stream import DecodeStream +from lhotse import CutSet, Fbank, FbankConfig +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
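To make the comment above concrete, the magic number 23 used for the tail padding can be sanity-checked as follows (a sketch, not part of the recipe):

tail_length = 23
after_embedding = (tail_length - 7) // 2  # 8 frames remain after the feature-embedding subsampling
max_downsampling_factor = 8               # the largest zipformer downsampling factor
assert after_embedding >= max_downsampling_factor  # at least one frame survives at the coarsest stack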
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankConfig( + device=device, + dither=0.0, + snip_edges=False, + sampling_rate=16000, + num_mel_bins=80, + high_freq=-400.0, + ) + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
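If this loop is adapted to read recordings from elsewhere, the waveform must already be scaled the way the assertions above expect; a minimal sketch of that normalization, assuming 16-bit PCM input:

import numpy as np

pcm = np.array([[0, 16384, -32768]], dtype=np.int16)  # toy samples, shape (1, num_samples)
audio = pcm.astype(np.float32) / 32768.0              # now within [-1, 1]
assert audio.dtype == np.float32 and np.abs(audio).max() <= 10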
+ + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank.extract(samples.to(device), sampling_rate=16000) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out CERs, per-word error statistics and aligned + # ref/hyp pairs. 
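Each of those aligned pairs comes from an entry of decode_results, which is a (cut_id, ref_words, hyp_words) triple; a toy, purely hypothetical example of the structure passed to store_transcripts and write_error_stats:

# Hypothetical entry; real ids and transcripts come from the KsponSpeech cuts.
results = [
    ("KsponSpeech_E00001", ["ref", "words", "here"], ["hyp", "words", "here"]),
]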
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + cer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + KsponSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + 
raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + ksponspeech = KsponSpeechAsrDataModule(args) + + eval_clean_cuts = ksponspeech.eval_clean_cuts() + eval_other_cuts = ksponspeech.eval_other_cuts() + + test_sets = ["eval_clean", "eval_other"] + test_cuts = [eval_clean_cuts, eval_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..a465758f5 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +To run this file, do: + + cd icefall/egs/ksponspeech/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_small(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,2,2,2,2" + params.feedforward_dims = "256,256,512,512,256" + params.nhead = "4,4,4,4,4" + params.encoder_dims = "128,128,128,128,128" + params.attention_dims = "96,96,96,96,96" + params.encoder_unmasked_dims = "96,96,96,96,96" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 320 + params.joiner_dim = 320 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + import pdb + + pdb.set_trace() + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model_small() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..bf50bf5ea --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1243 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import KsponSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of 
zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. 
+ """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> 
Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
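+            # The block below reads the current scale via the private `_scale`
+            # attribute and, when it has collapsed, doubles it by passing an
+            # explicit value to scaler.update(); a persistently tiny scale means
+            # recent batches kept producing inf/nan gradients, and training is
+            # aborted once it falls below 1e-5.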
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise_grad_scale_is_too_small_error(cur_grad_scale) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            512
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    ksponspeech = KsponSpeechAsrDataModule(args)
+
+    train_cuts = ksponspeech.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 32 seconds
+        #
+        # Caution: There is a reason to select 32.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 32.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + # train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = ksponspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = ksponspeech.dev_cuts() + + # valid_cuts = valid_cuts.filter(remove_short_and_long_utt) + valid_dl = ksponspeech.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + KsponSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/shared b/egs/ksponspeech/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/ksponspeech/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file From 1f5c0a87b9b89fb25d5ea853d07335119767ebed Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 16 Jun 2024 19:15:09 +0800 Subject: [PATCH 177/216] Add CI for ksponspeech (#1655) --- .github/scripts/ksponspeech/ASR/run.sh | 72 +++++++++++++++ .github/workflows/ksponspeech.yml | 118 +++++++++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100755 .github/scripts/ksponspeech/ASR/run.sh create mode 100644 .github/workflows/ksponspeech.yml diff --git a/.github/scripts/ksponspeech/ASR/run.sh b/.github/scripts/ksponspeech/ASR/run.sh new file mode 100755 index 000000000..068c22dfc --- /dev/null +++ b/.github/scripts/ksponspeech/ASR/run.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/ksponspeech/ASR + + +function test_pretrained() { + git lfs install + git clone https://huggingface.co/johnBamma/icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12 + repo=icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12 + pushd $repo + mkdir test_wavs + cd test_wavs + curl -SL -O https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/0.wav + curl -SL -O https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/1.wav + curl -SL -O https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/2.wav + curl -SL -O 
https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/3.wav + cd ../exp + ln -s pretrained.pt epoch-99.pt + ls -lh + popd + + log 'test pretrained.py' + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_5000/tokens.txt \ + --method greedy_search \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav + + log 'test export-onnx.py' + + ./pruned_transducer_stateless7_streaming/export-onnx.py \ + --tokens $repo/data/lang_bpe_5000/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + + ls -lh $repo/exp + + ls -lh $repo/data/lang_bpe_5000/ + + log 'test exported onnx models' + ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_5000/tokens.txt \ + $repo/test_wavs/0.wav + + dst=/tmp/model1 + mkdir -p $dst + + cp -v $repo/exp/*.onnx $dst + cp -v $repo/exp/*.onnx $dst + cp -v $repo/data/lang_bpe_5000/tokens.txt $dst + cp -v $repo/data/lang_bpe_5000/bpe.model $dst + rm -rf $repo +} + +test_pretrained diff --git a/.github/workflows/ksponspeech.yml b/.github/workflows/ksponspeech.yml new file mode 100644 index 000000000..2e1441c06 --- /dev/null +++ b/.github/workflows/ksponspeech.yml @@ -0,0 +1,118 @@ +name: ksponspeech + +on: + push: + branches: + - ksponspeech + + workflow_dispatch: + +jobs: + ksponspeech: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Test + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/ksponspeech/ASR/run.sh + + - name: Show model files + shell: bash + run: | + src=/tmp/model1 + ls -lh $src + + - name: Upload model to huggingface + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + uses: nick-fields/retry@v3 + with: + max_attempts: 20 + timeout_seconds: 200 + shell: bash + command: | + src=/tmp/model1 + git config --global user.email "csukuangfj@gmail.com" + git config --global user.name "Fangjun Kuang" + + rm -rf hf + export GIT_LFS_SKIP_SMUDGE=1 + export GIT_CLONE_PROTECTION_ACTIVE=false + + git clone https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16 hf + cd hf + git fetch + git pull + git merge -m "merge remote" --ff origin main + cp -v $src/* ./ + ls -lh + git lfs track "bpe.model" + git lfs track 
"*.onnx" + cp -av test_wavs $src/ + git add . + git status + git commit -m "update models" + git status + + git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16 main || true + rm -rf hf + + - name: Prepare for release + shell: bash + run: | + src=/tmp/model1 + d=sherpa-onnx-streaming-zipformer-korean-2024-06-16 + mv $src ./$d + tar cjvf ${d}.tar.bz2 $d + ls -lh + + - name: Release exported onnx models + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + overwrite: true + file: sherpa-onnx-*.tar.bz2 + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: asr-models From 2e05663fbbae3cf48ef55c42a973004df0b3ae38 Mon Sep 17 00:00:00 2001 From: Seung Hyun Lee Date: Tue, 18 Jun 2024 17:54:39 +0900 Subject: [PATCH 178/216] Add prepare.sh for KsponSpeech recipe. (#1656) --- egs/ksponspeech/ASR/prepare.sh | 162 +++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100755 egs/ksponspeech/ASR/prepare.sh diff --git a/egs/ksponspeech/ASR/prepare.sh b/egs/ksponspeech/ASR/prepare.sh new file mode 100755 index 000000000..2c5cc8b49 --- /dev/null +++ b/egs/ksponspeech/ASR/prepare.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=0 +stop_stage=100 + +# Note: This script just prepare the minimal requirements that needed by a +# transducer training with bpe units. +# +# We assume dl_dir (download dir) contains the following +# directories and files. +# This script downloads only musan dataset automatically. +# +# - $dl_dir/KsponSpeech +# This script doesn't download KsponSpeech dataset automatically. +# For more details, please visit: +# Dataset: https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=123 +# Paper: https://www.mdpi.com/2076-3417/10/19/6936 +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + 5000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +data=$PWD/data + +. shared/parse_options.sh || exit 1 + +mkdir -p $data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "Running prepare.sh" + +log "dl_dir: $dl_dir" + + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download MUSAN data" + # Befor you run this script, you must get the KsponSpeech dataset + # from https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=123 + # If you have pre-downloaded it to /path/to/KsponSpeech, + # you can create a symlink + # + # ln -svf /path/to/KsponSpeech $dl_dir/KsponSpeech + # + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/musan + # + if [ ! 
-d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare KsponSpeech manifest" + # We assume that you have downloaded the KsponSpeech corpus + # to $dl_dir/KsponSpeech + mkdir -p $data/manifests + if [ ! -e $data/manifests/.ksponspeech.done ]; then + lhotse prepare ksponspeech -j $nj $dl_dir/KsponSpeech $data/manifests + touch $data/manifests/.ksponspeech.done + fi +fi + + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p $data/manifests + if [ ! -e $data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan $data/manifests + touch $data/manifests/.musan.done + fi +fi + + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for KsponSpeech" + mkdir -p $data/fbank + if [ ! -e $data/fbank/.ksponspeech.done ]; then + ./local/compute_fbank_ksponspeech.py --data-dir $data + touch $data/fbank/.ksponspeech.done + fi + + if [ ! -e $data/fbank/.ksponspeech-validated.done ]; then + log "Validating data/fbank for KsponSpeech" + parts=( + train + dev + eval_clean + eval_other + ) + for part in ${parts[@]}; do + ./local/validate_manifest.py \ + $data/fbank/ksponspeech_cuts_${part}.jsonl.gz + done + touch $data/fbank/.ksponspeech-validated.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p $data/fbank + if [ ! -e $data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py \ + --src-dir $data/manifests \ + --output-dir $data/fbank + touch $data/fbank/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=$data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + files=$( + find "$data/fbank" -name "ksponspeech_cuts_*.jsonl.gz" + ) + gunzip -c ${files} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt + fi + + if [ ! 
-f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + done +fi From ff2bef9e501a4b5ebfec04cbfe8afa2e8bea4b40 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Wed, 19 Jun 2024 11:10:31 +0800 Subject: [PATCH 179/216] update multi-hans whisper-qwen-1.5b results (#1657) --- egs/speech_llm/ASR_LLM/RESULTS.md | 29 +++++++++++++++++---- egs/speechio/ASR/local/normalize_results.py | 7 +++-- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md index dc2479054..05c0ffd27 100644 --- a/egs/speech_llm/ASR_LLM/RESULTS.md +++ b/egs/speech_llm/ASR_LLM/RESULTS.md @@ -2,12 +2,31 @@ ### whisper_llm_zh finetuning results -| Training Dataset | Speech Encoder | LLM | Projector |Comment | CER | -| -------------------------| ----------------|------|--------------------------------------------------|-----|--| -| Aishell1 | whisper-large-v2-aishell1-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 Test 3.62% | - +|Model| Training Dataset | Speech Encoder | LLM | Projector | +|-| -------------------------| ----------------|------|---------------| +|[yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 | whisper-large-v2-aishell1-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| +| [yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B) |Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample|| +CER Details: +| Model | [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | [yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B) | +|-------|------------------------------------------------|----------------------------------------------------| +| Split | Greedy Search | Greedy Search | +| aishell-1 dev | - | 0.66 | +| aishell-1 test | 3.62 | 0.68 | +| aishell-2 dev | - | 2.67 | +| aishell-2 test | - | 2.94 | +| aishell-4 test | - | 16.20 | +| alimeeting eval | - | 30.86 | +| alimeeting test | - | 40.50 | +| magicdata dev | - | 2.50 | +| magicdata test | - | 1.70 | +| kespeech-asr dev phase1 | - | 6.22 | +| kespeech-asr dev phase2 | - | 2.18 | +| kespeech-asr test | - | 6.59 | +| WenetSpeech dev | - | 4.59 | +| WenetSpeech test_meeting | - | 6.41 | +| WenetSpeech tes_net | - | 6.63 | +| SPEECHIO Avg 001-026 | - | 4.80 | Command for training is: ```bash pip install -r whisper_llm_zh/requirements.txt diff --git a/egs/speechio/ASR/local/normalize_results.py b/egs/speechio/ASR/local/normalize_results.py index 14eb1bb2f..79d886617 100755 --- a/egs/speechio/ASR/local/normalize_results.py +++ b/egs/speechio/ASR/local/normalize_results.py @@ -16,12 +16,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This file uses whisper and zipformer decoding results to generate fusion decoding results. 
-Since whisper model is more likely to make deletion errors and zipformer model is more likely to make substitution and insertion errors, -we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors. +This file uses speech io offcial pipline to normalize the decoding results. +https://github.com/SpeechColab/Leaderboard/blob/master/utils/textnorm_zh.py Usage: - python whisper_zipformer_fusion.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm + python normalize_results.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm """ import argparse From 3059eb4511c4305d9030cd34c03b26d524efacf7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 21 Jun 2024 11:10:14 +0800 Subject: [PATCH 180/216] Fix doc URLs (#1660) --- README.md | 13 +++++++++---- egs/aidatatang_200zh/ASR/README.md | 2 +- egs/aishell/ASR/tdnn_lstm_ctc/README.md | 2 +- egs/librispeech/ASR/README.md | 2 +- egs/librispeech/ASR/conformer_ctc/README.md | 4 ++-- egs/timit/ASR/README.md | 2 +- egs/yesno/ASR/README.md | 2 +- egs/yesno/ASR/tdnn/README.md | 2 +- 8 files changed, 17 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 770066166..31e514606 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,12 @@ Please refer to [document](https://k2-fsa.github.io/icefall/huggingface/spaces.h # Installation -Please refer to [document](https://icefall.readthedocs.io/en/latest/installation/index.html) +Please refer to [document](https://k2-fsa.github.io/icefall/installation/index.html) for installation. # Recipes -Please refer to [document](https://icefall.readthedocs.io/en/latest/recipes/index.html) +Please refer to [document](https://k2-fsa.github.io/icefall/recipes/index.html) for more details. ## ASR: Automatic Speech Recognition @@ -77,7 +77,7 @@ The [LibriSpeech][librispeech] recipe supports the most comprehensive set of mod #### Whisper - [OpenAi Whisper](https://arxiv.org/abs/2212.04356) (We support fine-tuning on AiShell-1.) -If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details. +If you are willing to contribute to icefall, please refer to [contributing](https://k2-fsa.github.io/icefall/contributing/index.html) for more details. We would like to highlight the performance of some of the recipes here. @@ -343,7 +343,12 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt Once you have trained a model in icefall, you may want to deploy it with C++ without Python dependencies. -Please refer to the [document](https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/librispeech/conformer_ctc.html#deployment-with-c) +Please refer to + + - https://k2-fsa.github.io/icefall/model-export/export-with-torch-jit-script.html + - https://k2-fsa.github.io/icefall/model-export/export-onnx.html + - https://k2-fsa.github.io/icefall/model-export/export-ncnn.html + for how to do this. We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++. 
diff --git a/egs/aidatatang_200zh/ASR/README.md b/egs/aidatatang_200zh/ASR/README.md index b85895a09..035139d17 100644 --- a/egs/aidatatang_200zh/ASR/README.md +++ b/egs/aidatatang_200zh/ASR/README.md @@ -6,7 +6,7 @@ The main repositories are list below, we will update the training and decoding s k2: https://github.com/k2-fsa/k2 icefall: https://github.com/k2-fsa/icefall lhotse: https://github.com/lhotse-speech/lhotse -* Install k2 and lhotse, k2 installation guide refers to https://k2.readthedocs.io/en/latest/installation/index.html, lhotse refers to https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. I think the latest version would be ok. And please also install the requirements listed in icefall. +* Install k2 and lhotse, k2 installation guide refers to https://k2-fsa.github.io/k2/installation/index.html, lhotse refers to https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. I think the latest version would be ok. And please also install the requirements listed in icefall. * Clone icefall(https://github.com/k2-fsa/icefall) and check to the commit showed above. ``` git clone https://github.com/k2-fsa/icefall diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/README.md b/egs/aishell/ASR/tdnn_lstm_ctc/README.md index a2d80a785..c003fd419 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/README.md +++ b/egs/aishell/ASR/tdnn_lstm_ctc/README.md @@ -1,4 +1,4 @@ Please visit - + for how to run this recipe. diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 080f81c91..93fef7a07 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -1,6 +1,6 @@ # Introduction -Please refer to for how to run models in this recipe. +Please refer to for how to run models in this recipe. [./RESULTS.md](./RESULTS.md) contains the latest results. diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md index 37ace4204..1bccccc73 100644 --- a/egs/librispeech/ASR/conformer_ctc/README.md +++ b/egs/librispeech/ASR/conformer_ctc/README.md @@ -1,7 +1,7 @@ ## Introduction Please visit - + for how to run this recipe. ## How to compute framewise alignment information @@ -9,7 +9,7 @@ for how to run this recipe. ### Step 1: Train a model Please use `conformer_ctc/train.py` to train a model. -See +See for how to do it. ### Step 2: Compute framewise alignment diff --git a/egs/timit/ASR/README.md b/egs/timit/ASR/README.md index d493fc479..f700fab9e 100644 --- a/egs/timit/ASR/README.md +++ b/egs/timit/ASR/README.md @@ -1,3 +1,3 @@ -Please refer to +Please refer to for how to run models in this recipe. diff --git a/egs/yesno/ASR/README.md b/egs/yesno/ASR/README.md index 38b491fc6..c9a2b56b1 100644 --- a/egs/yesno/ASR/README.md +++ b/egs/yesno/ASR/README.md @@ -10,5 +10,5 @@ get the following WER: ``` Please refer to - + for detailed instructions. diff --git a/egs/yesno/ASR/tdnn/README.md b/egs/yesno/ASR/tdnn/README.md index 2b6116f0a..1b7ddcaf1 100644 --- a/egs/yesno/ASR/tdnn/README.md +++ b/egs/yesno/ASR/tdnn/README.md @@ -2,7 +2,7 @@ ## How to run this recipe You can find detailed instructions by visiting - + It describes how to run this recipe and how to use a pre-trained model with `./pretrained.py`. 
From 6f102d34704cd2fe1a4b695e286f2d07e4c00551 Mon Sep 17 00:00:00 2001 From: Seung Hyun Lee Date: Mon, 24 Jun 2024 15:07:37 +0900 Subject: [PATCH 181/216] Add non-streaming Zipformer recipe for KsponSpeech (#1664) --- egs/ksponspeech/ASR/README.md | 3 +- egs/ksponspeech/ASR/RESULTS.md | 52 +- egs/ksponspeech/ASR/zipformer/README.md | 1 + .../ASR/zipformer/asr_datamodule.py | 1 + egs/ksponspeech/ASR/zipformer/beam_search.py | 1 + egs/ksponspeech/ASR/zipformer/ctc_decode.py | 844 ++++++++++ egs/ksponspeech/ASR/zipformer/decode.py | 1050 +++++++++++++ .../ASR/zipformer/decode_stream.py | 1 + egs/ksponspeech/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-ctc.py | 1 + .../zipformer/export-onnx-streaming-ctc.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/ksponspeech/ASR/zipformer/export-onnx.py | 1 + egs/ksponspeech/ASR/zipformer/export.py | 1 + .../ASR/zipformer/generate_averaged_model.py | 1 + egs/ksponspeech/ASR/zipformer/joiner.py | 1 + egs/ksponspeech/ASR/zipformer/model.py | 1 + egs/ksponspeech/ASR/zipformer/onnx_check.py | 1 + egs/ksponspeech/ASR/zipformer/onnx_decode.py | 1 + .../onnx_pretrained-streaming-ctc.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc.py | 1 + egs/ksponspeech/ASR/zipformer/optim.py | 1 + egs/ksponspeech/ASR/zipformer/pretrained.py | 1 + .../ASR/zipformer/pretrained_ctc.py | 1 + egs/ksponspeech/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 856 ++++++++++ egs/ksponspeech/ASR/zipformer/subsampling.py | 1 + egs/ksponspeech/ASR/zipformer/test_scaling.py | 1 + .../ASR/zipformer/test_subsampling.py | 1 + egs/ksponspeech/ASR/zipformer/train.py | 1381 +++++++++++++++++ egs/ksponspeech/ASR/zipformer/zipformer.py | 1 + 36 files changed, 4212 insertions(+), 4 deletions(-) create mode 100644 egs/ksponspeech/ASR/zipformer/README.md create mode 120000 egs/ksponspeech/ASR/zipformer/asr_datamodule.py create mode 120000 egs/ksponspeech/ASR/zipformer/beam_search.py create mode 100755 egs/ksponspeech/ASR/zipformer/ctc_decode.py create mode 100755 egs/ksponspeech/ASR/zipformer/decode.py create mode 120000 egs/ksponspeech/ASR/zipformer/decode_stream.py create mode 120000 egs/ksponspeech/ASR/zipformer/decoder.py create mode 120000 egs/ksponspeech/ASR/zipformer/encoder_interface.py create mode 120000 egs/ksponspeech/ASR/zipformer/export-onnx-ctc.py create mode 120000 egs/ksponspeech/ASR/zipformer/export-onnx-streaming-ctc.py create mode 120000 egs/ksponspeech/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/ksponspeech/ASR/zipformer/export-onnx.py create mode 120000 egs/ksponspeech/ASR/zipformer/export.py create mode 120000 egs/ksponspeech/ASR/zipformer/generate_averaged_model.py create mode 120000 egs/ksponspeech/ASR/zipformer/joiner.py create mode 120000 egs/ksponspeech/ASR/zipformer/model.py create mode 120000 egs/ksponspeech/ASR/zipformer/onnx_check.py create mode 120000 egs/ksponspeech/ASR/zipformer/onnx_decode.py create mode 120000 egs/ksponspeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py create mode 120000 egs/ksponspeech/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/ksponspeech/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/ksponspeech/ASR/zipformer/onnx_pretrained_ctc.py create mode 120000 egs/ksponspeech/ASR/zipformer/optim.py create mode 120000 
egs/ksponspeech/ASR/zipformer/pretrained.py create mode 120000 egs/ksponspeech/ASR/zipformer/pretrained_ctc.py create mode 120000 egs/ksponspeech/ASR/zipformer/scaling.py create mode 120000 egs/ksponspeech/ASR/zipformer/scaling_converter.py create mode 120000 egs/ksponspeech/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/ksponspeech/ASR/zipformer/streaming_decode.py create mode 120000 egs/ksponspeech/ASR/zipformer/subsampling.py create mode 120000 egs/ksponspeech/ASR/zipformer/test_scaling.py create mode 120000 egs/ksponspeech/ASR/zipformer/test_subsampling.py create mode 100755 egs/ksponspeech/ASR/zipformer/train.py create mode 120000 egs/ksponspeech/ASR/zipformer/zipformer.py diff --git a/egs/ksponspeech/ASR/README.md b/egs/ksponspeech/ASR/README.md index 2b02b9cca..44a75ca27 100644 --- a/egs/ksponspeech/ASR/README.md +++ b/egs/ksponspeech/ASR/README.md @@ -27,6 +27,7 @@ There are various folders containing the name `transducer` in this folder. The f | | Encoder | Decoder | Comment | | ---------------------------------------- | -------------------- | ------------------ | ------------------------------------------------- | -| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 | +| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 | +| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe | The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer. \ No newline at end of file diff --git a/egs/ksponspeech/ASR/RESULTS.md b/egs/ksponspeech/ASR/RESULTS.md index 66edf8e66..5d8001062 100644 --- a/egs/ksponspeech/ASR/RESULTS.md +++ b/egs/ksponspeech/ASR/RESULTS.md @@ -19,13 +19,13 @@ The CERs are: | fast beam search | 320ms | 10.21 | 11.04 | --epoch 30 --avg 9 | simulated streaming | | fast beam search | 320ms | 10.25 | 11.08 | --epoch 30 --avg 9 | chunk-wise | | modified beam search | 320ms | 10.13 | 10.88 | --epoch 30 --avg 9 | simulated streaming | -| modified beam search | 320ms | 10.1 | 10.93 | --epoch 30 --avg 9 | chunk-size | +| modified beam search | 320ms | 10.1 | 10.93 | --epoch 30 --avg 9 | chunk-wize | | greedy search | 640ms | 9.94 | 10.82 | --epoch 30 --avg 9 | simulated streaming | | greedy search | 640ms | 10.04 | 10.85 | --epoch 30 --avg 9 | chunk-wise | | fast beam search | 640ms | 10.01 | 10.81 | --epoch 30 --avg 9 | simulated streaming | | fast beam search | 640ms | 10.04 | 10.7 | --epoch 30 --avg 9 | chunk-wise | | modified beam search | 640ms | 9.91 | 10.72 | --epoch 30 --avg 9 | simulated streaming | -| modified beam search | 640ms | 9.92 | 10.72 | --epoch 30 --avg 9 | chunk-size | +| modified beam search | 640ms | 9.92 | 10.72 | --epoch 30 --avg 9 | chunk-wize | Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. 
@@ -67,4 +67,50 @@ for m in greedy_search modified_beam_search fast_beam_search; do --decode-chunk-len 32 \ --num-decode-streams 2000 done -``` \ No newline at end of file +``` + +### zipformer (Zipformer + pruned statelss transducer) + +#### [zipformer](./zipformer) + +Number of model parameters: 74,778,511, i.e., 74.78 M + +##### Training on KsponSpeech (with MUSAN) + +Model: [johnBamma/icefall-asr-ksponspeech-zipformer-2024-06-24](https://huggingface.co/johnBamma/icefall-asr-ksponspeech-zipformer-2024-06-24) + +The CERs are: + +| decoding method | eval_clean | eval_other | comment | +|----------------------|------------|------------|---------------------| +| greedy search | 10.60 | 11.56 | --epoch 30 --avg 9 | +| fast beam search | 10.59 | 11.54 | --epoch 30 --avg 9 | +| modified beam search | 10.35 | 11.35 | --epoch 30 --avg 9 | + +The training command is: + +```bash +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 750 \ + --enable-musan True \ + --base-lr 0.035 +``` + +NOTICE: I decreased `base_lr` from 0.045(default) to 0.035, Because of `RuntimeError: grad_scale is too small`. + +The decoding command is: + +```bash +for m in greedy_search fast_beam_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir zipformer/exp \ + --decoding-method $m +done +``` diff --git a/egs/ksponspeech/ASR/zipformer/README.md b/egs/ksponspeech/ASR/zipformer/README.md new file mode 100644 index 000000000..c8c2104cd --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/README.md @@ -0,0 +1 @@ +This recipe implements Zipformer model. diff --git a/egs/ksponspeech/ASR/zipformer/asr_datamodule.py b/egs/ksponspeech/ASR/zipformer/asr_datamodule.py new file mode 120000 index 000000000..46254a1f2 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/asr_datamodule.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/beam_search.py b/egs/ksponspeech/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/ctc_decode.py b/egs/ksponspeech/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..9f04f5d4d --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/ctc_decode.py @@ -0,0 +1,844 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(2) 1best +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(3) nbest +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(4) nbest-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(5) whole-lattice-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import KsponSpeechAsrDataModule +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.6, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. 
`len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. 
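+        # (nbest_oracle samples --num-paths paths from the lattice and, for
+        # each utterance, keeps the path closest to the reference transcript
+        # in edit distance, so the WER it reports is a lower bound for any
+        # n-best rescoring method.)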
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.decoding_method in ["1best", "nbest"]:
+        if params.decoding_method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        return {key: hyps}
+
+    assert params.decoding_method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.decoding_method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.decoding_method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.decoding_method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
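+    # len(dl) is not defined for every sampler/dataloader combination (it can
+    # raise TypeError), so the batch count is best-effort and is only used for
+    # the progress logging below.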
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out CERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + cer = write_error_stats(f, f"{test_set_name}-{key}", results, compute_CER=True) + test_set_cers[key] = cer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + KsponSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
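+        # A causal (streaming) model is usually trained with comma-separated
+        # lists for these options (e.g. --chunk-size "16,32,64,-1") so that
+        # several configurations are sampled during training; at decoding time
+        # a single value has to be picked, hence the two asserts above.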
+        params.suffix += f"-chunk-{params.chunk_size}"
+        params.suffix += f"-left-context-{params.left_context_frames}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    params.vocab_size = num_classes
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = 0
+
+    if params.decoding_method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        HLG.scores *= params.hlg_scale
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.decoding_method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        if params.decoding_method == "whole-lattice-rescoring":
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
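+        # (Roughly speaking, rescore_with_whole_lattice drops the 3-gram
+        # scores that HLG contributed to the lattice and replaces them with
+        # these 4-gram scores, scaled by each value in lm_scale_list.)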
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
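+    # (Setting return_cuts=True makes the lhotse dataset attach the original
+    # Cut objects under batch["supervisions"]["cut"], which is where
+    # decode_dataset() reads each cut.id from.)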
+ args.return_cuts = True + + ksponspeech = KsponSpeechAsrDataModule(args) + + eval_clean_cuts = ksponspeech.eval_clean_cuts() + eval_other_cuts = ksponspeech.eval_other_cuts() + + eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts) + eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts) + + test_sets = ["eval_clean", "eval_other"] + test_dl = [eval_clean_dl, eval_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/zipformer/decode.py b/egs/ksponspeech/ASR/zipformer/decode.py new file mode 100755 index 000000000..be42898b7 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/decode.py @@ -0,0 +1,1050 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm 
+import torch +import torch.nn as nn +from asr_datamodule import KsponSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. 
For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + 
encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: 
spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out CERs, per-word error statistics and aligned + # ref/hyp pairs. 
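+        # (Passing compute_CER=True below makes write_error_stats report the
+        # character error rate, which is the metric normally used for Korean
+        # corpora such as KsponSpeech.)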
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + cer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"cer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + KsponSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+        params.suffix += f"-chunk-{params.chunk_size}"
+        params.suffix += f"-left-context-{params.left_context_frames}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
+            if params.has_contexts:
+                params.suffix += f"-context-score-{params.context_score}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_shallow_fusion:
+        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+    if "LODR" in params.decoding_method:
+        params.suffix += (
+            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+        )
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + ksponspeech = KsponSpeechAsrDataModule(args) + + eval_clean_cuts = ksponspeech.eval_clean_cuts() + eval_other_cuts = ksponspeech.eval_other_cuts() + + eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts) + eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts) + + test_sets = ["eval_clean", "eval_other"] + test_dl = [eval_clean_dl, eval_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/zipformer/decode_stream.py b/egs/ksponspeech/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/decoder.py b/egs/ksponspeech/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/encoder_interface.py b/egs/ksponspeech/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/export-onnx-ctc.py b/egs/ksponspeech/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 000000000..f9d756352 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/ksponspeech/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 120000 index 000000000..652346001 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/export-onnx-streaming.py b/egs/ksponspeech/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/export-onnx.py b/egs/ksponspeech/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/export.py b/egs/ksponspeech/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/generate_averaged_model.py 
b/egs/ksponspeech/ASR/zipformer/generate_averaged_model.py new file mode 120000 index 000000000..5a015ee6c --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/joiner.py b/egs/ksponspeech/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/model.py b/egs/ksponspeech/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/onnx_check.py b/egs/ksponspeech/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/onnx_decode.py b/egs/ksponspeech/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/ksponspeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py new file mode 120000 index 000000000..d623a8462 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/ksponspeech/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/onnx_pretrained.py b/egs/ksponspeech/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/ksponspeech/ASR/zipformer/onnx_pretrained_ctc.py new file mode 120000 index 000000000..a3183ebf6 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/optim.py b/egs/ksponspeech/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/pretrained.py b/egs/ksponspeech/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git 
a/egs/ksponspeech/ASR/zipformer/pretrained_ctc.py b/egs/ksponspeech/ASR/zipformer/pretrained_ctc.py new file mode 120000 index 000000000..c2f6f6fc3 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/scaling.py b/egs/ksponspeech/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/scaling_converter.py b/egs/ksponspeech/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/streaming_beam_search.py b/egs/ksponspeech/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/streaming_decode.py b/egs/ksponspeech/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..9811bac7c --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/streaming_decode.py @@ -0,0 +1,856 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
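+
+# NOTE: this script performs *simulated* streaming decoding: complete test
+# recordings are loaded and then fed to the causal zipformer chunk by chunk
+# through DecodeStream objects, mimicking what a real-time client would send.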
+ +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import KsponSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. 
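+
+    Example (an illustrative sketch, not taken from a real run; ``states_a`` and
+    ``states_b`` stand for two per-utterance state lists such as those returned
+    by :func:`get_init_states` with ``batch_size=1``):
+
+      >>> batch_states = stack_states([states_a, states_b])
+      >>> # batch_states now holds the same tensors, batched for 2 utterances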
+ """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
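+
+    Example (illustrative only; assumes ``batch_states`` was produced by
+    :func:`stack_states` for two utterances):
+
+      >>> states_a, states_b = unstack_states(batch_states)
+      >>> # stack_states([states_a, states_b]) rebuilds an equivalent batched state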
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
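+
+    Roughly (following the shape notes used elsewhere in this file): `features`
+    holds `chunk_size * 2` fbank frames per utterance (plus right padding),
+    `encoder_embed` reduces this to `chunk_size` frames at 50 Hz (checked by the
+    assert below), and the encoder then produces `encoder_out` together with the
+    refreshed cached states returned in `new_states`.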
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
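+        # It carries this utterance's features, cached encoder states and
+        # decoding result; at most --num-decode-streams of them are decoded
+        # in lock-step by decode_one_chunk() below.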
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out CERs, per-word error statistics and aligned + # ref/hyp pairs. 
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            cer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True,
+            )
+            test_set_cers[key] = cer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"cer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tCER", file=f)
+        for key, val in test_set_cers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_cers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    KsponSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    assert params.causal, params.causal
+    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+    assert (
+        "," not in params.left_context_frames
+    ), "left_context_frames should be one value in decoding."
+    params.suffix += f"-chunk-{params.chunk_size}"
+    params.suffix += f"-left-context-{params.left_context_frames}"
+
+    # for fast_beam_search
+    if params.decoding_method == "fast_beam_search":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if start >= 0:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + ksponspeech = KsponSpeechAsrDataModule(args) + + eval_clean_cuts = ksponspeech.eval_clean_cuts() + eval_other_cuts = ksponspeech.eval_other_cuts() + + test_sets = ["eval_clean", "eval_other"] + test_cuts = [eval_clean_cuts, eval_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/zipformer/subsampling.py b/egs/ksponspeech/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/test_scaling.py b/egs/ksponspeech/ASR/zipformer/test_scaling.py new file mode 120000 index 000000000..715798436 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/test_scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/test_subsampling.py b/egs/ksponspeech/ASR/zipformer/test_subsampling.py new file mode 120000 index 000000000..bf0ee3d11 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/test_subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py new file mode 100755 index 000000000..5957fe1fb --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -0,0 +1,1381 @@ +#!/usr/bin/env python3 +# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) +# +# See ../../../../LICENSE for clarification regarding multiple authors 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import KsponSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
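+    # E.g. with --max-duration 1000 (as in the usage example at the top of this
+    # file), --world-size 4 and the default --ref-duration 600, each real batch
+    # counts as 1000 * 4 / 600, i.e. ~6.7 reference batches.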
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + 
vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. 
It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info( + f"Caught exception: {e}." + ) + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
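+            # In practice the code below doubles the scale every 100 batches
+            # while it is below 8.0 (and every 400 batches while below 32.0),
+            # and aborts with a saved "bad-model" checkpoint if it collapses.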
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if not params.use_transducer:
+        params.ctc_loss_scale = 1.0
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,  # should have no effect
+        clipping_scale=2.0,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            512
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    ksponspeech = KsponSpeechAsrDataModule(args)
+
+    train_cuts = ksponspeech.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 32 seconds
+        #
+        # Caution: There is a reason to select 32.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 32.0:
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + # train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = ksponspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = ksponspeech.dev_cuts() + valid_dl = ksponspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + KsponSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ksponspeech/ASR/zipformer/zipformer.py b/egs/ksponspeech/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/ksponspeech/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 031f8927961c667d1f713bc1340ba2888b157b4d Mon Sep 17 00:00:00 2001 From: Seung Hyun Lee Date: Mon, 24 Jun 2024 16:28:09 +0900 Subject: [PATCH 182/216] Reformat by black non-streaming zipformer recipe for ksponspeech (#1665) --- egs/ksponspeech/ASR/zipformer/ctc_decode.py | 8 +++++--- egs/ksponspeech/ASR/zipformer/decode.py | 6 +++++- egs/ksponspeech/ASR/zipformer/streaming_decode.py | 6 +++++- egs/ksponspeech/ASR/zipformer/train.py | 4 +--- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/egs/ksponspeech/ASR/zipformer/ctc_decode.py b/egs/ksponspeech/ASR/zipformer/ctc_decode.py index 9f04f5d4d..30bf1610b 100755 --- a/egs/ksponspeech/ASR/zipformer/ctc_decode.py +++ b/egs/ksponspeech/ASR/zipformer/ctc_decode.py @@ -571,7 +571,9 @@ def save_results( # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: - cer = write_error_stats(f, f"{test_set_name}-{key}", results, compute_CER=True) + cer = write_error_stats( + f, f"{test_set_name}-{key}", results, compute_CER=True + ) test_set_cers[key] = cer logging.info("Wrote detailed error stats to {}".format(errs_filename)) @@ -807,7 +809,7 @@ def main(): # we need cut ids to display recognition results. 
args.return_cuts = True - + ksponspeech = KsponSpeechAsrDataModule(args) eval_clean_cuts = ksponspeech.eval_clean_cuts() @@ -815,7 +817,7 @@ def main(): eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts) eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts) - + test_sets = ["eval_clean", "eval_other"] test_dl = [eval_clean_dl, eval_other_dl] diff --git a/egs/ksponspeech/ASR/zipformer/decode.py b/egs/ksponspeech/ASR/zipformer/decode.py index be42898b7..5c21abb79 100755 --- a/egs/ksponspeech/ASR/zipformer/decode.py +++ b/egs/ksponspeech/ASR/zipformer/decode.py @@ -727,7 +727,11 @@ def save_results( ) with open(errs_filename, "w") as f: cer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_cers[key] = cer diff --git a/egs/ksponspeech/ASR/zipformer/streaming_decode.py b/egs/ksponspeech/ASR/zipformer/streaming_decode.py index 9811bac7c..73a681c6a 100755 --- a/egs/ksponspeech/ASR/zipformer/streaming_decode.py +++ b/egs/ksponspeech/ASR/zipformer/streaming_decode.py @@ -659,7 +659,11 @@ def save_results( ) with open(errs_filename, "w") as f: cer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_cers[key] = cer diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py index 5957fe1fb..b612b6835 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -961,9 +961,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except Exception as e: - logging.info( - f"Caught exception: {e}." 
- ) + logging.info(f"Caught exception: {e}.") save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise From b594a3875ba9bbaeea62500b3672f06d8fe59332 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 24 Jun 2024 16:20:46 +0800 Subject: [PATCH 183/216] Add CI for non-streaming zipformer about ksponspeech (#1667) --- .github/scripts/ksponspeech/ASR/run.sh | 66 ++++++++++++++++++++++++-- .github/workflows/ksponspeech.yml | 61 +++++++++++++++++++++--- 2 files changed, 118 insertions(+), 9 deletions(-) diff --git a/.github/scripts/ksponspeech/ASR/run.sh b/.github/scripts/ksponspeech/ASR/run.sh index 068c22dfc..5c7886463 100755 --- a/.github/scripts/ksponspeech/ASR/run.sh +++ b/.github/scripts/ksponspeech/ASR/run.sh @@ -11,7 +11,66 @@ log() { cd egs/ksponspeech/ASR -function test_pretrained() { +function test_pretrained_non_streaming() { + git lfs install + git clone https://huggingface.co/johnBamma/icefall-asr-ksponspeech-zipformer-2024-06-24 + repo=icefall-asr-ksponspeech-zipformer-2024-06-24 + pushd $repo + mkdir test_wavs + cd test_wavs + curl -SL -O https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/0.wav + curl -SL -O https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/1.wav + curl -SL -O https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/2.wav + curl -SL -O https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/3.wav + curl -SL -O https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16/resolve/main/test_wavs/trans.txt + cd ../exp + ln -s pretrained.pt epoch-99.pt + ls -lh + popd + + log 'test pretrained.py' + ./zipformer/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_5000/tokens.txt \ + --method greedy_search \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav + + log 'test export-onnx.py' + + ./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_5000/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ + + ls -lh $repo/exp + + ls -lh $repo/data/lang_bpe_5000/ + + log 'test exported onnx models' + ./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_5000/tokens.txt \ + $repo/test_wavs/0.wav + + dst=/tmp/model-2024-06-24 + mkdir -p $dst + + cp -av $repo/test_wavs $dst + cp -v $repo/exp/*.onnx $dst + cp -v $repo/exp/*.onnx $dst + cp -v $repo/data/lang_bpe_5000/tokens.txt $dst + cp -v $repo/data/lang_bpe_5000/bpe.model $dst + rm -rf $repo +} + +function test_pretrained_streaming() { git lfs install git clone https://huggingface.co/johnBamma/icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12 repo=icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12 @@ -59,7 +118,7 @@ function test_pretrained() { --tokens $repo/data/lang_bpe_5000/tokens.txt \ $repo/test_wavs/0.wav - dst=/tmp/model1 + dst=/tmp/model-2024-06-16 mkdir -p $dst cp -v $repo/exp/*.onnx $dst @@ -69,4 +128,5 @@ function test_pretrained() { rm -rf $repo } -test_pretrained +test_pretrained_non_streaming +test_pretrained_streaming diff --git a/.github/workflows/ksponspeech.yml 
b/.github/workflows/ksponspeech.yml index 2e1441c06..6c4fc546d 100644 --- a/.github/workflows/ksponspeech.yml +++ b/.github/workflows/ksponspeech.yml @@ -57,13 +57,19 @@ jobs: .github/scripts/ksponspeech/ASR/run.sh - - name: Show model files + - name: Show model files (2024-06-24) shell: bash run: | - src=/tmp/model1 + src=/tmp/model-2024-06-24 ls -lh $src - - name: Upload model to huggingface + - name: Show model files (2024-06-16) + shell: bash + run: | + src=/tmp/model-2024-06-16 + ls -lh $src + + - name: Upload model to huggingface (2024-06-24) env: HF_TOKEN: ${{ secrets.HF_TOKEN }} uses: nick-fields/retry@v3 @@ -72,7 +78,41 @@ jobs: timeout_seconds: 200 shell: bash command: | - src=/tmp/model1 + src=/tmp/model-2024-06-24 + git config --global user.email "csukuangfj@gmail.com" + git config --global user.name "Fangjun Kuang" + + rm -rf hf + export GIT_LFS_SKIP_SMUDGE=1 + export GIT_CLONE_PROTECTION_ACTIVE=false + + git clone https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-korean-2024-06-24 hf + cd hf + git fetch + git pull + git merge -m "merge remote" --ff origin main + cp -av $src/* ./ + ls -lh + git lfs track "bpe.model" + git lfs track "*.onnx" + git add . + git status + git commit -m "update models" + git status + + git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-zipformer-korean-2024-06-24 main || true + rm -rf hf + + - name: Upload model to huggingface (2024-06-16) + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + uses: nick-fields/retry@v3 + with: + max_attempts: 20 + timeout_seconds: 200 + shell: bash + command: | + src=/tmp/model-2024-06-16 git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -98,15 +138,24 @@ jobs: git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16 main || true rm -rf hf - - name: Prepare for release + - name: Prepare for release (2024-06-16) shell: bash run: | - src=/tmp/model1 + src=/tmp/model-2024-06-16 d=sherpa-onnx-streaming-zipformer-korean-2024-06-16 mv $src ./$d tar cjvf ${d}.tar.bz2 $d ls -lh + - name: Prepare for release (2024-06-24) + shell: bash + run: | + src=/tmp/model-2024-06-24 + d=sherpa-onnx-zipformer-korean-2024-06-24 + mv $src ./$d + tar cjvf ${d}.tar.bz2 $d + ls -lh + - name: Release exported onnx models uses: svenstaro/upload-release-action@v2 with: From eaab2c819f8118b0845a0094be607fa893b09981 Mon Sep 17 00:00:00 2001 From: Manix <50542248+manickavela29@users.noreply.github.com> Date: Thu, 27 Jun 2024 13:38:24 +0530 Subject: [PATCH 184/216] Zipformer Onnx FP16 (#1671) Signed-off-by: manickavela29 --- .../ASR/zipformer/export-onnx-streaming.py | 32 +++++++++++++++++-- egs/librispeech/ASR/zipformer/export-onnx.py | 30 +++++++++++++++-- requirements.txt | 1 + 3 files changed, 58 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 5d0c9ea43..e5ceb3683 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -48,7 +48,8 @@ popd --joiner-dim 512 \ --causal True \ --chunk-size 16 \ - --left-context-frames 128 + --left-context-frames 128 \ + --fp16 True The --chunk-size in training is "16,32,64,-1", so we select one of them (excluding -1) during streaming export. 
The same applies to `--left-context`, @@ -73,6 +74,7 @@ import onnx import torch import torch.nn as nn from decoder import Decoder +from onnxconverter_common import float16 from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -154,6 +156,13 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + add_model_arguments(parser) return parser @@ -479,7 +488,6 @@ def export_encoder_model_onnx( add_meta_data(filename=encoder_filename, meta_data=meta_data) - def export_decoder_model_onnx( decoder_model: OnnxDecoder, decoder_filename: str, @@ -747,11 +755,29 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + if(params.fp16) : + logging.info("Generate fp16 models") + + encoder = onnx.load(encoder_filename) + encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) + encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" + onnx.save(encoder_fp16,encoder_filename_fp16) + + decoder = onnx.load(decoder_filename) + decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) + decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" + onnx.save(decoder_fp16,decoder_filename_fp16) + + joiner = onnx.load(joiner_filename) + joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) + joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" + onnx.save(joiner_fp16,joiner_filename_fp16) + # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection logging.info("Generate int8 quantization models") - + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" quantize_dynamic( model_input=encoder_filename, diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index 3682f0b62..ed8a0ef0f 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -48,8 +48,8 @@ popd --joiner-dim 512 \ --causal False \ --chunk-size "16,32,64,-1" \ - --left-context-frames "64,128,256,-1" - + --left-context-frames "64,128,256,-1" \ + --fp16 True It will generate the following 3 files inside $repo/exp: - encoder-epoch-99-avg-1.onnx @@ -70,6 +70,7 @@ import onnx import torch import torch.nn as nn from decoder import Decoder +from onnxconverter_common import float16 from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -151,6 +152,13 @@ def get_parser(): help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + add_model_arguments(parser) return parser @@ -584,6 +592,24 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + if(params.fp16) : + logging.info("Generate fp16 models") + + encoder = onnx.load(encoder_filename) + encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) + encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" + onnx.save(encoder_fp16,encoder_filename_fp16) + + decoder = onnx.load(decoder_filename) + decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) + decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" + onnx.save(decoder_fp16,decoder_filename_fp16) + + joiner = onnx.load(joiner_filename) + joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) + joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" + onnx.save(joiner_fp16,joiner_filename_fp16) + # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection diff --git a/requirements.txt b/requirements.txt index 226adaba1..d97263142 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ onnx>=1.15.0 onnxruntime>=1.16.3 onnxoptimizer onnxsim +onnxconverter_common # style check session: black==22.3.0 From ebbd396c2bbe8f2bf626fef4e3778c32d28dc301 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Wed, 3 Jul 2024 19:55:12 +0800 Subject: [PATCH 185/216] update multi-hans-zh whisper-qwen-7b results (#1677) * update qwen-7b whisper encoder results * update qwen-7b whisper encoder results * fix typo --- egs/speech_llm/ASR_LLM/RESULTS.md | 55 ++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md index 05c0ffd27..830c70397 100644 --- a/egs/speech_llm/ASR_LLM/RESULTS.md +++ b/egs/speech_llm/ASR_LLM/RESULTS.md @@ -6,27 +6,30 @@ |-| -------------------------| ----------------|------|---------------| |[yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 | whisper-large-v2-aishell1-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| | [yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B) |Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample|| +| [yuekai/icefall_asr_multi-hans_whisper_qwen2_7B](https://huggingface.co/yuekai/icefall_asr_multi-hans_whisper_qwen2_7B) |Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze| Qwen2-7B-Instruct, LoRA | Linear, 8x downsample|| CER Details: -| Model | [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | [yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B) | -|-------|------------------------------------------------|----------------------------------------------------| -| Split | Greedy Search | Greedy Search | -| aishell-1 dev | - | 0.66 | -| aishell-1 test | 3.62 | 0.68 | -| aishell-2 dev | - | 2.67 | -| aishell-2 test | - | 2.94 | -| aishell-4 test | - | 16.20 | -| alimeeting eval | - | 30.86 | -| alimeeting test | - | 40.50 | -| magicdata dev | - | 2.50 | -| magicdata 
test | - | 1.70 | -| kespeech-asr dev phase1 | - | 6.22 | -| kespeech-asr dev phase2 | - | 2.18 | -| kespeech-asr test | - | 6.59 | -| WenetSpeech dev | - | 4.59 | -| WenetSpeech test_meeting | - | 6.41 | -| WenetSpeech tes_net | - | 6.63 | -| SPEECHIO Avg 001-026 | - | 4.80 | +| Model | [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | [yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_multi-hans_whisper_qwen2_1.5B) | [yuekai/icefall_asr_multi-hans_whisper_qwen2_7B](https://huggingface.co/yuekai/icefall_asr_multi-hans_whisper_qwen2_7B) | +|-------|------------------------------------------------|----------------------------------------------------|-| +| Split | Greedy Search | Greedy Search | Greedy Search | +| aishell-1 dev | - | 0.66 | 0.49| +| aishell-1 test | 3.62 | 0.68 | 0.51 | +| aishell-2 dev | - | 2.67 | 2.61 | +| aishell-2 test | - | 2.94 | 2.76 | +| aishell-4 test | - | 16.20 | 15.82 | +| alimeeting eval | - | 30.86 | 29.27 | +| alimeeting test | - | 40.50 | 39.48 | +| magicdata dev | - | 2.50 | 2.27 | +| magicdata test | - | 1.70 | 1.57 | +| kespeech-asr dev phase1 | - | 6.22 | 4.87 | +| kespeech-asr dev phase2 | - | 2.18 | 1.87 | +| kespeech-asr test | - | 6.59 | 5.76 | +| WenetSpeech dev | - | 4.59 | 4.41 | +| WenetSpeech test_meeting | - | 6.41 | 6.06 | +| WenetSpeech tes_net | - | 6.63 | 6.30 | +| SPEECHIO Avg 001-026 | - | 4.80 | 4.50 | + + Command for training is: ```bash pip install -r whisper_llm_zh/requirements.txt @@ -42,6 +45,19 @@ huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishel # huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct +# First, we only train the projector and freeze other modules. +torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ + --max-duration 200 \ + --exp-dir ./whisper_llm_zh/exp_test \ + --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ + --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \ + --manifest-dir data/fbank \ + --deepspeed \ + --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ + --use-flash-attn True \ + --use-lora False --unfreeze-llm False + +# Then we jointly train the projector and LLM LoRA modules. 
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --max-duration 200 \ --exp-dir ./whisper_llm_zh/exp_test \ @@ -52,6 +68,7 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ --use-lora True --unfreeze-llm True + --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt ``` Command for decoding using fine-tuned models: From cbcac23d2617ccfdc8f1ecc14a00ba96413c3bf9 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 4 Jul 2024 14:19:45 +0800 Subject: [PATCH 186/216] Fix typos, remove unused packages, normalize comments (#1678) --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- egs/librispeech/ASR/zipformer/decode.py | 1 - egs/librispeech/ASR/zipformer/export-onnx-ctc.py | 1 - egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py | 3 --- egs/librispeech/ASR/zipformer/joiner.py | 2 +- egs/librispeech/ASR/zipformer/onnx_check.py | 2 -- egs/librispeech/ASR/zipformer/optim.py | 4 ++-- egs/librispeech/ASR/zipformer/pretrained.py | 2 -- egs/librispeech/ASR/zipformer/scaling_converter.py | 5 +++-- egs/librispeech/ASR/zipformer/train.py | 4 ---- 10 files changed, 7 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 814390ad6..1b52aa8b5 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -1,4 +1,4 @@ -# Copyright 2021 Piotr Żelasko +# Copyright 2021 Piotr Żelasko # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 339e253e6..df2d555a0 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -133,7 +133,6 @@ from icefall.checkpoint import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - make_pad_mask, setup_logger, store_transcripts, str2bool, diff --git a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py index 3345d20d3..99685f2fe 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py @@ -72,7 +72,6 @@ import k2 import onnx import torch import torch.nn as nn -from decoder import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py index eade5a854..c13c4ccc8 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py @@ -40,15 +40,12 @@ Usage of this script: import argparse import logging -import math from typing import List, Optional import k2 -import kaldifeat import torch import torchaudio from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature -from torch.nn.utils.rnn import pad_sequence def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index dfb0a0057..0406efe83 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -45,7 +45,7 @@ class 
Joiner(nn.Module): Output from the encoder. Its shape is (N, T, s_range, C). decoder_out: Output from the decoder. Its shape is (N, T, s_range, C). - project_input: + project_input: If true, apply input projections encoder_proj and decoder_proj. If this is false, it is the user's responsibility to do this manually. diff --git a/egs/librispeech/ASR/zipformer/onnx_check.py b/egs/librispeech/ASR/zipformer/onnx_check.py index 93bd3a211..b558a5dfc 100755 --- a/egs/librispeech/ASR/zipformer/onnx_check.py +++ b/egs/librispeech/ASR/zipformer/onnx_check.py @@ -82,8 +82,6 @@ import logging import torch from onnx_pretrained import OnnxModel -from icefall import is_module_available - def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index aaffbfed5..6f5180e29 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -1,4 +1,4 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) # # See ../LICENSE for clarification regarding multiple authors # @@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch from lhotse.utils import fix_random_seed -from torch import Tensor, nn +from torch import Tensor from torch.optim import Optimizer diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index de0652893..9f3571b08 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -126,8 +126,6 @@ from export import num_tokens from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_model, get_params -from icefall.utils import make_pad_mask - def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py index 76622fa12..1f95648a0 100644 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -1,4 +1,5 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -22,7 +23,7 @@ BasicNorm is replaced by a module with `exp` removed. """ import copy -from typing import List, Tuple +from typing import List import torch import torch.nn as nn diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 04caf2fd8..858f845dc 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -512,10 +512,6 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - warm_step: The warmup period that dictates the decay of the scale on "simple" (un-pruned) loss. 
""" From f76afff74144a0f57b7f1fc09e051bf058d1f1d2 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 5 Jul 2024 20:19:18 +0800 Subject: [PATCH 187/216] Support CTC/AED option for Zipformer recipe (#1389) * add attention-decoder loss option for zipformer recipe * add attention-decoder-rescoring * update export.py and pretrained_ctc.py * update RESULTS.md --- egs/librispeech/ASR/RESULTS.md | 70 +++ .../ASR/zipformer/attention_decoder.py | 573 ++++++++++++++++++ egs/librispeech/ASR/zipformer/ctc_decode.py | 91 ++- egs/librispeech/ASR/zipformer/export.py | 3 +- .../ASR/zipformer/label_smoothing.py | 109 ++++ egs/librispeech/ASR/zipformer/model.py | 30 +- .../ASR/zipformer/pretrained_ctc.py | 40 +- egs/librispeech/ASR/zipformer/train.py | 94 ++- icefall/decode.py | 232 +++++++ 9 files changed, 1221 insertions(+), 21 deletions(-) create mode 100644 egs/librispeech/ASR/zipformer/attention_decoder.py create mode 100644 egs/librispeech/ASR/zipformer/label_smoothing.py diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ee5422aba..6f00bc14d 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,75 @@ ## Results +### zipformer (zipformer + CTC/AED) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +Results of the CTC head: + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-decoding | 2.29 | 5.14 | --epoch 50 --avg 29 | +| attention-decoder-rescoring-no-ngram | 2.1 | 4.57 | --epoch 50 --avg 29 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --full-libri 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --ctc-loss-scale 0.1 \ + --attention-decoder-loss-scale 0.9 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 1200 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-decoding attention-decoder-rescoring-no-ngram; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 29 \ + --exp-dir zipformer/exp-large \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --attention-decoder-loss-scale 0.9 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 100 \ + --causal 0 \ + --num-paths 100 \ + --decoding-method $m +done +``` + + ### zipformer (zipformer + pruned stateless transducer + CTC) See for more details. diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py new file mode 100644 index 000000000..71be2d1eb --- /dev/null +++ b/egs/librispeech/ASR/zipformer/attention_decoder.py @@ -0,0 +1,573 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import List, Optional + +import k2 +import torch +import torch.nn as nn + +from label_smoothing import LabelSmoothingLoss +from icefall.utils import add_eos, add_sos, make_pad_mask +from scaling import penalize_abs_values_gt + + +class AttentionDecoderModel(nn.Module): + """ + Args: + vocab_size (int): Number of classes. + decoder_dim: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + num_heads (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int = 512, + num_decoder_layers: int = 6, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + sos_id: int = 1, + eos_id: int = 1, + dropout: float = 0.1, + ignore_id: int = -1, + label_smoothing: float = 0.1, + ): + super().__init__() + self.eos_id = eos_id + self.sos_id = sos_id + self.ignore_id = ignore_id + + # For the segment of the warmup period, we let the Embedding + # layer learn something. Then we start to warm up the other encoders. + self.decoder = TransformerDecoder( + vocab_size=vocab_size, + d_model=decoder_dim, + num_decoder_layers=num_decoder_layers, + attention_dim=attention_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + memory_dim=memory_dim, + dropout=dropout, + ) + + # Used to calculate attention-decoder loss + self.loss_fun = LabelSmoothingLoss( + ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum" + ) + + def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor): + """Prepare ys_in_pad and ys_out_pad.""" + ys_in = add_sos(ys, sos_id=self.sos_id) + # [B, S+1], start with SOS + ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id) + ys_in_lens = ys_lens + 1 + + ys_out = add_eos(ys, eos_id=self.eos_id) + # [B, S+1], end with EOS + ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id) + + return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64) + + def calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys: k2.RaggedTensor, + ys_lens: torch.Tensor, + ) -> torch.Tensor: + """Calculate attention-decoder loss. + Args: + encoder_out: (batch, num_frames, encoder_dim) + encoder_out_lens: (batch,) + token_ids: A list of token id list. + + Return: The attention-decoder loss. 
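As a concrete illustration of what `_pre_ys_in_out` above produces, here is a small runnable check with toy token IDs, assuming the default `sos_id = eos_id = 1` and `ignore_id = -1`:

```python
import k2
from icefall.utils import add_eos, add_sos

ys = k2.RaggedTensor([[5, 9, 3], [7, 2]])  # two toy utterances
ys_in_pad = add_sos(ys, sos_id=1).pad(mode="constant", padding_value=1)
ys_out_pad = add_eos(ys, eos_id=1).pad(mode="constant", padding_value=-1)
print(ys_in_pad)   # [[1, 5, 9, 3], [1, 7, 2, 1]]   SOS-prefixed, padded with eos_id
print(ys_out_pad)  # [[5, 9, 3, 1], [7, 2, 1, -1]]  EOS-suffixed, padded with ignore_id
```

The decoder therefore consumes the SOS-prefixed sequence and is trained to predict each next token plus a final EOS, while the `-1` padding positions are skipped by the label-smoothing loss.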
+ """ + ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens) + + # decoder forward + decoder_out = self.decoder( + x=ys_in_pad, + x_lens=ys_in_lens, + memory=encoder_out, + memory_lens=encoder_out_lens, + ) + + loss = self.loss_fun(x=decoder_out, target=ys_out_pad) + return loss + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + token_ids: List[List[int]], + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from attention-decoder. + Args: + encoder_out: (batch, num_frames, encoder_dim) + encoder_out_lens: (batch,) + token_ids: A list of token id list. + + Return: A tensor of shape (batch, num_tokens). + """ + ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device) + row_splits = ys.shape.row_splits(1) + ys_lens = row_splits[1:] - row_splits[:-1] + + ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens) + + # decoder forward + decoder_out = self.decoder( + x=ys_in_pad, + x_lens=ys_in_lens, + memory=encoder_out, + memory_lens=encoder_out_lens, + ) + + batch_size, _, num_classes = decoder_out.size() + nll = nn.functional.cross_entropy( + decoder_out.view(-1, num_classes), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction="none", + ) + nll = nll.view(batch_size, -1) + return nll + + +class TransformerDecoder(nn.Module): + """Transfomer decoder module. + + Args: + vocab_size: output dim + d_model: decoder dimension + num_decoder_layers: number of decoder layers + attention_dim: total dimension of multi head attention + num_heads: number of attention heads + feedforward_dim: hidden dimension of feed_forward module + dropout: dropout rate + """ + + def __init__( + self, + vocab_size: int, + d_model: int = 512, + num_decoder_layers: int = 6, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + dropout: float = 0.1, + ): + super().__init__() + self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model) + + # Absolute positional encoding + self.pos = PositionalEncoding(d_model, dropout_rate=0.1) + + self.num_layers = num_decoder_layers + self.layers = nn.ModuleList( + [ + DecoderLayer( + d_model=d_model, + attention_dim=attention_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + memory_dim=memory_dim, + dropout=dropout, + ) + for _ in range(num_decoder_layers) + ] + ) + + self.output_layer = nn.Linear(d_model, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + memory: Optional[torch.Tensor] = None, + memory_lens: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch, tgt_len). + x_lens: A tensor of shape (batch,) containing the number of tokens in `x` + before padding. + memory: + Memory sequence of shape (batch, src_len, memory_dim). + memory_lens: + A tensor of shape (batch,) containing the number of frames in + `memory` before padding. 
+ + Returns: + Decoded token logits before softmax (batch, tgt_len, vocab_size) + """ + x = self.embed(x) # (batch, tgt_len, embed_dim) + x = self.pos(x) # (batch, tgt_len, embed_dim) + + x = x.permute(1, 0, 2) # (tgt_len, batch, embed_dim) + + # construct attn_mask for self-attn modules + padding_mask = make_pad_mask(x_lens) # (batch, tgt_len) + causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len) + attn_mask = torch.logical_or( + padding_mask.unsqueeze(1), # (batch, 1, seq_len) + torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len) + ) # (batch, seq_len, seq_len) + + if memory is not None: + memory = memory.permute(1, 0, 2) # (src_len, batch, memory_dim) + # construct memory_attn_mask for cross-attn modules + memory_padding_mask = make_pad_mask(memory_lens) # (batch, src_len) + memory_attn_mask = memory_padding_mask.unsqueeze(1) # (batch, 1, src_len) + else: + memory_attn_mask = None + + for i, mod in enumerate(self.layers): + x = mod( + x, + attn_mask=attn_mask, + memory=memory, + memory_attn_mask=memory_attn_mask, + ) + + x = x.permute(1, 0, 2) # (batch, tgt_len, vocab_size) + x = self.output_layer(x) + + return x + + +class DecoderLayer(nn.Module): + """Single decoder layer module. + + Args: + d_model: equal to decoder_dim, total dimension of the decoder + attention_dim: total dimension of multi head attention + num_heads: number of attention heads + feedforward_dim: hidden dimension of feed_forward module + dropout: dropout rate + """ + + def __init__( + self, + d_model: int = 512, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + dropout: float = 0.1, + ): + """Construct an DecoderLayer object.""" + super(DecoderLayer, self).__init__() + + self.norm_self_attn = nn.LayerNorm(d_model) + self.self_attn = MultiHeadAttention( + d_model, attention_dim, num_heads, dropout=0.0 + ) + + self.norm_src_attn = nn.LayerNorm(d_model) + self.src_attn = MultiHeadAttention( + d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0 + ) + + self.norm_ff = nn.LayerNorm(d_model) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, feedforward_dim), + Swish(), + nn.Dropout(dropout), + nn.Linear(feedforward_dim, d_model), + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: Input sequence of shape (seq_len, batch, embed_dim). + attn_mask: A binary mask for self-attention module indicating which + elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + memory: Memory sequence of shape (seq_len, batch, memory_dim). + memory_attn_mask: A binary mask for cross-attention module indicating which + elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + """ + # self-attn module + qkv = self.norm_self_attn(x) + self_attn_out = self.self_attn( + query=qkv, key=qkv, value=qkv, attn_mask=attn_mask + ) + x = x + self.dropout(self_attn_out) + + # cross-attn module + q = self.norm_src_attn(x) + src_attn_out = self.src_attn( + query=q, key=memory, value=memory, attn_mask=memory_attn_mask + ) + x = x + self.dropout(src_attn_out) + + # feed-forward module + x = x + self.dropout(self.feed_forward(self.norm_ff(x))) + + return x + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention layer. 
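To see how the padding mask and the causal mask combine in `TransformerDecoder.forward` above, here is a small check with made-up lengths (a batch of 2 with target lengths 3 and 2):

```python
import torch
from icefall.utils import make_pad_mask

x_lens = torch.tensor([3, 2])                                  # toy target lengths
padding_mask = make_pad_mask(x_lens)                           # (2, 3), True at padded positions
causal_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool))   # same as subsequent_mask(3)
attn_mask = torch.logical_or(
    padding_mask.unsqueeze(1),                                 # (2, 1, 3)
    torch.logical_not(causal_mask).unsqueeze(0),               # (1, 3, 3)
)                                                              # (2, 3, 3), True = masked out

print(attn_mask[1].int())
# [[0, 1, 1],
#  [0, 0, 1],
#  [0, 0, 1]]
```

The `True` entries are exactly the positions that `MultiHeadAttention` fills with `-inf`, so the second utterance can never attend to its padded third position nor to future tokens.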
+ + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, but must be a multiple of num_heads. + num_heads: number of parallel attention heads. + memory_dim: dimension of memory embedding, optional. + dropout: a Dropout layer on attn_output_weights. + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + memory_dim: Optional[int] = None, + dropout: float = 0.0, + ): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.head_dim = attention_dim // num_heads + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, num_heads, attention_dim + ) + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True) + self.linear_k = nn.Linear( + embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True + ) + self.linear_v = nn.Linear( + embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True + ) + + self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute dot product attention. + + Args: + query: Query tensor of shape (tgt_len, batch, embed_dim). + key: Key tensor of shape (src_len, batch, embed_dim or memory_dim). + value: Value tensor of shape (src_len, batch, embed_dim or memory_dim). + key_padding_mask: A binary mask indicating which elements are padding. + Its shape is (batch, src_len). + attn_mask: A binary mask indicating which elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + + Returns: + Output tensor of shape (tgt_len, batch, embed_dim). + """ + num_heads = self.num_heads + head_dim = self.head_dim + + tgt_len, batch, _ = query.shape + src_len = key.shape[0] + + q = self.linear_q(query) # (tgt_len, batch, num_heads * head_dim) + k = self.linear_k(key) # (src_len, batch, num_heads * head_dim) + v = self.linear_v(value) # (src_len, batch, num_heads * head_dim) + + q = q.reshape(tgt_len, batch, num_heads, head_dim) + q = q.permute(1, 2, 0, 3) # (batch, head, tgt_len, head_dim) + k = k.reshape(src_len, batch, num_heads, head_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, head_dim, src_len) + v = v.reshape(src_len, batch, num_heads, head_dim) + v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1) + + # Note: could remove the scaling operation when using ScaledAdam + # (batch, head, tgt_len, src_len) + attn_weights = torch.matmul(q, k) / math.sqrt(head_dim) + + # From zipformer.py: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. 
+ attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), + ) + + if attn_mask is not None: + assert ( + attn_mask.shape == (batch, 1, src_len) + or attn_mask.shape == (batch, tgt_len, src_len) + ), attn_mask.shape + attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf")) + + attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + # (batch * head, tgt_len, head_dim) + attn_output = torch.bmm(attn_weights, v) + assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape + + attn_output = attn_output.transpose(0, 1).contiguous() + attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim) + + # (batch, tgt_len, embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class PositionalEncoding(nn.Module): + """Positional encoding. + Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def subsequent_mask(size, device="cpu", dtype=torch.bool): + """Create mask for subsequent steps (size, size). 
+ + :param int size: size of mask + :param str device: "cpu" or "cuda" or torch.Tensor.device + :param torch.dtype dtype: result dtype + :rtype: torch.Tensor + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + ret = torch.ones(size, size, device=device, dtype=dtype) + return torch.tril(ret, out=ret) + + +def _test_attention_decoder_model(): + m = AttentionDecoderModel( + vocab_size=500, + decoder_dim=512, + num_decoder_layers=6, + attention_dim=512, + num_heads=8, + feedforward_dim=2048, + memory_dim=384, + dropout=0.1, + sos_id=1, + eos_id=1, + ignore_id=-1, + ) + + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + m.eval() + encoder_out = torch.randn(2, 50, 384) + encoder_out_lens = torch.full((2,), 50) + token_ids = [[1, 2, 3, 4], [2, 3, 10]] + + nll = m.nll(encoder_out, encoder_out_lens, token_ids) + print(nll) + + +if __name__ == "__main__": + _test_attention_decoder_model() diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 1f0f9bfac..85ceb61b8 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -73,6 +73,29 @@ Usage: --nbest-scale 1.0 \ --lm-dir data/lm \ --decoding-method whole-lattice-rescoring + +(6) attention-decoder-rescoring-no-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --decoding-method attention-decoder-rescoring-no-ngram + +(7) attention-decoder-rescoring-with-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method attention-decoder-rescoring-with-ngram """ @@ -101,6 +124,8 @@ from icefall.decode import ( nbest_decoding, nbest_oracle, one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_attention_decoder_with_ngram, rescore_with_n_best_list, rescore_with_whole_lattice, ) @@ -212,6 +237,10 @@ def get_parser(): - (6) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. + - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM + rescored lattice, rescore them with the attention decoder. """, ) @@ -406,6 +435,26 @@ def decode_one_batch( key = "ctc-decoding" return {key: hyps} + if params.decoding_method == "attention-decoder-rescoring-no-ngram": + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + ans = dict() + for a_scale_str, best_path in best_path_dict.items(): + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + if params.decoding_method == "nbest-oracle": # Note: You can also pass rescored lattices to it. 
# We choose the HLG decoded lattice for speed reasons @@ -446,6 +495,7 @@ def decode_one_batch( assert params.decoding_method in [ "nbest-rescoring", "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ] lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] @@ -466,6 +516,21 @@ def decode_one_batch( G_with_epsilon_loops=G, lm_scale_list=lm_scale_list, ) + elif params.decoding_method == "attention-decoder-rescoring-with-ngram": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + best_path_dict = rescore_with_attention_decoder_with_ngram( + lattice=rescored_lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) else: assert False, f"Unsupported decoding method: {params.decoding_method}" @@ -564,12 +629,21 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): + if params.decoding_method in ( + "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring" + ): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. @@ -577,8 +651,8 @@ def save_results( with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" @@ -616,6 +690,8 @@ def main(): "nbest-rescoring", "whole-lattice-rescoring", "nbest-oracle", + "attention-decoder-rescoring-no-ngram", + "attention-decoder-rescoring-with-ngram", ) params.res_dir = params.exp_dir / params.decoding_method @@ -654,8 +730,10 @@ def main(): params.vocab_size = num_classes # and are defined in local/train_bpe_model.py params.blank_id = 0 + params.eos_id = 1 + params.sos_id = 1 - if params.decoding_method == "ctc-decoding": + if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: HLG = None H = k2.ctc_topo( max_token=max_token_id, @@ -679,6 +757,7 @@ def main(): if params.decoding_method in ( "nbest-rescoring", "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ): if not (params.lm_dir / "G_4_gram.pt").is_file(): logging.info("Loading G_4_gram.fst.txt") @@ -710,7 +789,9 @@ def main(): d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) G = k2.Fsa.from_dict(d) - if params.decoding_method == "whole-lattice-rescoring": + if params.decoding_method in [ + "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram" + ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = k2.add_epsilon_self_loops(G) diff --git a/egs/librispeech/ASR/zipformer/export.py 
b/egs/librispeech/ASR/zipformer/export.py index 2b8d1aaf3..1f3373cd8 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -404,6 +404,7 @@ def main(): token_table = k2.SymbolTable.from_file(params.tokens) params.blank_id = token_table[""] + params.sos_id = params.eos_id = token_table[""] params.vocab_size = num_tokens(token_table) + 1 logging.info(params) @@ -466,8 +467,6 @@ def main(): device=device, ) ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: assert params.avg > 0, params.avg start = params.epoch - params.avg diff --git a/egs/librispeech/ASR/zipformer/label_smoothing.py b/egs/librispeech/ASR/zipformer/label_smoothing.py new file mode 100644 index 000000000..52d2eda3b --- /dev/null +++ b/egs/librispeech/ASR/zipformer/label_smoothing.py @@ -0,0 +1,109 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class LabelSmoothingLoss(torch.nn.Module): + """ + Implement the LabelSmoothingLoss proposed in the following paper + https://arxiv.org/pdf/1512.00567.pdf + (Rethinking the Inception Architecture for Computer Vision) + + """ + + def __init__( + self, + ignore_index: int = -1, + label_smoothing: float = 0.1, + reduction: str = "sum", + ) -> None: + """ + Args: + ignore_index: + ignored class id + label_smoothing: + smoothing rate (0.0 means the conventional cross entropy loss) + reduction: + It has the same meaning as the reduction in + `torch.nn.CrossEntropyLoss`. It can be one of the following three + values: (1) "none": No reduction will be applied. (2) "mean": the + mean of the output is taken. (3) "sum": the output will be summed. + """ + super().__init__() + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.ignore_index of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. 
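To make the smoothing concrete, here is a quick check with toy sizes (5 classes and the default `label_smoothing = 0.1`), using the same formula that `forward()` below applies: each reference row gets `1 - eps + eps / C` on the target class and `eps / C` everywhere else.

```python
import torch

eps, C = 0.1, 5                                # toy smoothing rate and class count
target = torch.tensor([2])                     # a single reference token
one_hot = torch.nn.functional.one_hot(target, num_classes=C).float()
true_dist = one_hot * (1 - eps) + eps / C      # the smoothed target distribution
print(true_dist)                               # [[0.02, 0.02, 0.92, 0.02, 0.02]], rows sum to 1
```

The loss is then `-(log_softmax(x) * true_dist).sum()`, with rows belonging to `ignore_index` targets zeroed out, as implemented in the method body that follows.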
+ """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) + + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + ) + + # Set the value of ignored indexes to 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 86da3ab29..bd1ed26d8 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -34,11 +34,13 @@ class AsrModel(nn.Module): encoder: EncoderInterface, decoder: Optional[nn.Module] = None, joiner: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, encoder_dim: int = 384, decoder_dim: int = 512, vocab_size: int = 500, use_transducer: bool = True, use_ctc: bool = False, + use_attention_decoder: bool = False, ): """A joint CTC & Transducer ASR model. @@ -70,6 +72,8 @@ class AsrModel(nn.Module): Whether use transducer head. Default: True. use_ctc: Whether use CTC head. Default: False. + use_attention_decoder: + Whether use attention-decoder head. Default: False. 
""" super().__init__() @@ -111,6 +115,12 @@ class AsrModel(nn.Module): nn.LogSoftmax(dim=-1), ) + self.use_attention_decoder = use_attention_decoder + if use_attention_decoder: + self.attention_decoder = attention_decoder + else: + assert attention_decoder is None + def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -286,7 +296,7 @@ class AsrModel(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -308,7 +318,7 @@ class AsrModel(nn.Module): part Returns: Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss) + in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss) Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -322,6 +332,8 @@ class AsrModel(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + device = x.device + # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) @@ -333,7 +345,7 @@ class AsrModel(nn.Module): simple_loss, pruned_loss = self.forward_transducer( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - y=y.to(x.device), + y=y.to(device), y_lens=y_lens, prune_range=prune_range, am_scale=am_scale, @@ -355,4 +367,14 @@ class AsrModel(nn.Module): else: ctc_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss + if self.use_attention_decoder: + attention_decoder_loss = self.attention_decoder.calc_att_loss( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ys=y.to(device), + ys_lens=y_lens.to(device), + ) + else: + attention_decoder_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index 408d13576..4341ef61f 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -81,6 +81,15 @@ Usage of this script: --sample-rate 16000 \ /path/to/foo.wav \ /path/to/bar.wav + +(5) attention-decoder-rescoring-no-ngram +./zipformer/pretrained_ctc.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method attention-decoder-rescoring-no-ngram \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav """ import argparse @@ -100,6 +109,7 @@ from train import add_model_arguments, get_model, get_params from icefall.decode import ( get_lattice, one_best_decoding, + rescore_with_attention_decoder_no_ngram, rescore_with_n_best_list, rescore_with_whole_lattice, ) @@ -172,6 +182,8 @@ def get_parser(): decoding lattice and then use 1best to decode the rescored lattice. We call it HLG decoding + whole-lattice n-gram LM rescoring. + (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. 
""", ) @@ -276,6 +288,7 @@ def main(): token_table = k2.SymbolTable.from_file(params.tokens) params.vocab_size = num_tokens(token_table) + 1 # +1 for blank params.blank_id = token_table[""] + params.sos_id = params.eos_id = token_table[""] assert params.blank_id == 0 logging.info(f"{params}") @@ -333,16 +346,13 @@ def main(): dtype=torch.int32, ) - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") + if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: max_token_id = params.vocab_size - 1 - H = k2.ctc_topo( max_token=max_token_id, modified=False, device=device, ) - lattice = get_lattice( nnet_output=ctc_output, decoding_graph=H, @@ -354,9 +364,23 @@ def main(): subsampling_factor=params.subsampling_factor, ) - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + else: + logging.info("Use attention decoder rescoring without ngram") + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + token_ids = get_texts(best_path) hyps = [[token_table[i] for i in ids] for ids in token_ids] elif params.method in [ @@ -430,7 +454,7 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - if params.method == "ctc-decoding": + if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: for filename, hyp in zip(params.sound_files, hyps): words = "".join(hyp) words = words.replace("▁", " ").strip() diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 858f845dc..3797de484 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -48,6 +48,8 @@ It supports training with: - transducer loss (default), with `--use-transducer True --use-ctc False` - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` + - ctc loss & attention decoder loss, no transducer loss, + with `--use-transducer False --use-ctc True --use-attention-decoder True` """ @@ -66,6 +68,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -221,6 +224,41 @@ def add_model_arguments(parser: argparse.ArgumentParser): """, ) + parser.add_argument( + "--attention-decoder-dim", + type=int, + default=512, + help="""Dimension used in the attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-dim", + type=int, + default=512, + help="""Attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-dim", + type=int, + default=2048, + 
help="""Feedforward dimension used in attention decoder""", + ) + parser.add_argument( "--causal", type=str2bool, @@ -259,6 +297,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use CTC head.", ) + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -404,6 +449,13 @@ def get_parser(): help="Scale for CTC loss.", ) + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + parser.add_argument( "--seed", type=int, @@ -528,6 +580,9 @@ def get_params() -> AttributeDict: # parameters for zipformer "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, "warm_step": 2000, "env_info": get_env_info(), } @@ -600,6 +655,23 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + def get_model(params: AttributeDict) -> nn.Module: assert params.use_transducer or params.use_ctc, ( f"At least one of them should be True, " @@ -617,16 +689,23 @@ def get_model(params: AttributeDict) -> nn.Module: decoder = None joiner = None + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, + attention_decoder=attention_decoder, encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, use_transducer=params.use_transducer, use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, ) return model @@ -789,7 +868,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -819,6 +898,9 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + assert loss.requires_grad == is_training info = MetricsTracker() @@ -833,6 +915,8 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() return loss, info @@ -1112,10 +1196,16 @@ def run(rank, world_size, args): # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() if not params.use_transducer: - params.ctc_loss_scale = 1.0 + if not params.use_attention_decoder: + 
params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, params.attention_decoder_loss_scale + ) logging.info(params) diff --git a/icefall/decode.py b/icefall/decode.py index 23f9fb9b3..3abd5648a 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1083,6 +1083,238 @@ def rescore_with_attention_decoder( return ans +def rescore_with_attention_decoder_with_ngram( + lattice: k2.Fsa, + num_paths: int, + attention_decoder: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + nbest_scale: float = 1.0, + ngram_lm_scale: Optional[float] = None, + attention_scale: Optional[float] = None, + use_double_scores: bool = True, +) -> Dict[str, k2.Fsa]: + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to extract from the given lattice for rescoring. + attention_decoder: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + encoder_out: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `(N, T, C)`. + encoder_out_lens: + Length of encoder outputs, with shape of `(N,)`. + nbest_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. + Returns: + A dict of FsaVec, whose key contains a string + ngram_lm_scale_attention_scale and the value is the + best decoding path for each utterance in the lattice. + """ + max_loop_count = 10 + loop_count = 0 + while loop_count <= max_loop_count: + try: + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # nbest.fsa.scores are all 0s at this point + nbest = nbest.intersect(lattice) + break + except RuntimeError as e: + logging.info(f"Caught exception:\n{e}\n") + logging.info(f"num_paths before decreasing: {num_paths}") + num_paths = int(num_paths / 2) + if loop_count >= max_loop_count or num_paths <= 0: + logging.info("Return None as the resulting lattice is too large.") + return None + logging.info( + "This OOM is not an error. You can ignore it. " + "If your model does not converge well, or --max-duration " + "is too large, or the input sound file is difficult to " + "decode, you will meet this exception." + ) + logging.info(f"num_paths after decreasing: {num_paths}") + loop_count += 1 + + # Now nbest.fsa has its scores set. + # Also, nbest.fsa inherits the attributes from `lattice`. + assert hasattr(nbest.fsa, "lm_scores") + + am_scores = nbest.compute_am_scores() + ngram_lm_scores = nbest.compute_lm_scores() + + # The `tokens` attribute is set inside `compile_hlg.py` + assert hasattr(nbest.fsa, "tokens") + assert isinstance(nbest.fsa.tokens, torch.Tensor) + + path_to_utt_map = nbest.shape.row_ids(1).to(torch.long) + # the shape of memory is (T, N, C), so we use axis=1 here + expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map) + expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map) + + # remove axis corresponding to states. 
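+    # nbest.fsa is an FsaVec with one FSA per extracted path, so its arcs have axes
+    # [path][state][arc]; removing the state axis leaves [path][arc], and pairing that
+    # shape with nbest.fsa.tokens groups each arc's token id by path.  The call to
+    # remove_values_leq(0) below then drops non-token entries (ids <= 0), leaving one
+    # clean token-id list per candidate path for the attention decoder to score.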
+ tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) + tokens = tokens.remove_values_leq(0) + token_ids = tokens.tolist() + + nll = attention_decoder.nll( + encoder_out=expanded_encoder_out, + encoder_out_lens=expanded_encoder_out_lens, + token_ids=token_ids, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + if ngram_lm_scale is None: + ngram_lm_scale_list = [0.01, 0.05, 0.08] + ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + else: + ngram_lm_scale_list = [ngram_lm_scale] + + if attention_scale is None: + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + else: + attention_scale_list = [attention_scale] + + ans = dict() + for n_scale in ngram_lm_scale_list: + for a_scale in attention_scale_list: + tot_scores = ( + am_scores.values + + n_scale * ngram_lm_scores.values + + a_scale * attention_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" + ans[key] = best_path + return ans + + +def rescore_with_attention_decoder_no_ngram( + lattice: k2.Fsa, + num_paths: int, + attention_decoder: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + nbest_scale: float = 1.0, + attention_scale: Optional[float] = None, + use_double_scores: bool = True, +) -> Dict[str, k2.Fsa]: + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to extract from the given lattice for rescoring. + attention_decoder: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + encoder_out: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `(N, T, C)`. + encoder_out_lens: + Length of encoder outputs, with shape of `(N,)`. + nbest_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. + attention_scale: + Optional. It specifies the scale for attention decoder scores. + + Returns: + A dict of FsaVec, whose key contains a string + ngram_lm_scale_attention_scale and the value is the + best decoding path for each utterance in the lattice. + """ + # path is a ragged tensor with dtype torch.int32. + # It has three axes [utt][path][arc_pos] + path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + # Note that labels, aux_labels and scores contains 0s and -1s. + # The last entry in each sublist is -1. 
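+    # Each value in `path` is an arc index into `lattice`, so gathering the per-arc
+    # labels, aux_labels and scores with it below turns every sampled path into its own
+    # label/token/score sequence; the trailing -1s are removed before the linear FSAs
+    # are built, and ids <= 0 are removed from the token ids that the decoder scores.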
+ # The axes are [path][token_id] + labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0) + aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0) + scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0) + + # Remove -1 from labels as we will use it to construct a linear FSA + labels = labels.remove_values_eq(-1) + fsa = k2.linear_fsa(labels) + fsa.aux_labels = aux_labels.values + + # utt_to_path_shape has axes [utt][path] + utt_to_path_shape = path.shape.get_layer(0) + scores = k2.RaggedTensor(utt_to_path_shape, scores.sum()) + + path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long) + # the shape of memory is (N, T, C), so we use axis=0 here + expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map) + expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map) + + token_ids = aux_labels.remove_values_leq(0).tolist() + + nll = attention_decoder.nll( + encoder_out=expanded_encoder_out, + encoder_out_lens=expanded_encoder_out_lens, + token_ids=token_ids, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + if attention_scale is None: + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0] + else: + attention_scale_list = [attention_scale] + + ans = dict() + + for a_scale in attention_scale_list: + tot_scores = scores.values + a_scale * attention_scores + ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(fsa, max_indexes) + + key = f"attention_scale_{a_scale}" + ans[key] = best_path + return ans + + def rescore_with_rnn_lm( lattice: k2.Fsa, num_paths: int, From 325a825841154cc0d652141ee773ee1dff3a581a Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 6 Jul 2024 09:01:19 +0800 Subject: [PATCH 188/216] Update requirements-ci.txt (#1682) --- requirements-ci.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-ci.txt b/requirements-ci.txt index ebea04615..59ab84e8c 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -28,5 +28,6 @@ multi_quantization onnx onnxmltools onnxruntime +onnxconverter_common kaldifst kaldi-decoder From 2d64228efae6ebb85b68a956942374a4801548ae Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 6 Jul 2024 09:01:34 +0800 Subject: [PATCH 189/216] Update attention_decoder.py (#1681) --- egs/librispeech/ASR/zipformer/attention_decoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py index 71be2d1eb..81682e87b 100644 --- a/egs/librispeech/ASR/zipformer/attention_decoder.py +++ b/egs/librispeech/ASR/zipformer/attention_decoder.py @@ -22,11 +22,11 @@ from typing import List, Optional import k2 import torch import torch.nn as nn - from label_smoothing import LabelSmoothingLoss -from icefall.utils import add_eos, add_sos, make_pad_mask from scaling import penalize_abs_values_gt +from icefall.utils import add_eos, add_sos, make_pad_mask + class AttentionDecoderModel(nn.Module): """ From 1c3d992a3989a4a4dc73b9fdb43e0a1b39038aee Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Tue, 9 Jul 2024 09:57:52 +0800 Subject: [PATCH 190/216] Update results using Zipformer-large on 
multi-hans-zh (#1679) --- egs/multi_zh-hans/ASR/RESULTS.md | 55 +++ .../ASR/whisper/multi_dataset.py | 248 +------------- egs/multi_zh-hans/ASR/zipformer/ctc_decode.py | 13 +- egs/multi_zh-hans/ASR/zipformer/decode.py | 6 +- .../ASR/zipformer/multi_dataset.py | 317 +----------------- egs/multi_zh-hans/ASR/zipformer/train.py | 40 +++ egs/speechio/ASR/local/normalize_results.py | 3 +- 7 files changed, 107 insertions(+), 575 deletions(-) mode change 100644 => 120000 egs/multi_zh-hans/ASR/whisper/multi_dataset.py mode change 100644 => 120000 egs/multi_zh-hans/ASR/zipformer/multi_dataset.py diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index e411e80a3..6f5c93ea9 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -43,6 +43,61 @@ Fine-tuned models, training logs, decoding logs, tensorboard and decoding result are available at +### Multi Chinese datasets char-based training results (streaming) on zipformer large model + +#### Streaming (with CTC head) + +The training command for large model (num of params : ~160M): + +Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features. + +``` +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 20 \ + --use-fp16 1 \ + --max-duration 1200 \ + --num-workers 8 \ + --use-ctc 1 \ + --exp-dir zipformer/exp-large \ + --causal 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 768,1024,1536,2048,1536,768 \ + --encoder-dim 256,384,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 + +``` + +The decoding command for transducer greedy search: + +``` +./zipformer/decode.py \ + --epoch 999 \ + --avg 1 \ + --causal 1 \ + --use-averaged-model False \ + --chunk_size -1 + --left-context-frames -1 \ + --use-ctc 1 \ + --exp-dir zipformer/exp-large \ + --max-duration 1200 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 768,1024,1536,2048,1536,768 \ + --encoder-dim 256,384,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 +``` + +Character Error Rates (CERs) listed below are produced by the checkpoint of the 18th epoch using BPE model ( # tokens is 2000, byte fallback enabled). 
+ +| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | +|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| +| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| CTC Greedy Streaming | 26.50 | 28.10| 1.71 | 1.97| 3.89| 4.06 | 17.23 | 3.69 | 2.87 | 8.14 | 3.61 |9.51 | 6.11 | 8.13 | 10.62 | +| CTC Greedy Offline | 23.47 | 25.02 | 1.39 | 1.50 | 3.15 | 3.41 | 15.14 | 3.07 | 2.37 | 6.06 | 2.90 | 7.13 | 5.40 | 6.52 | 9.64 | +| Transducer Greedy Offline | 23.16 | 24.78 | 1.33 | 1.38 | 3.06 | 3.23 | 15.36 | 2.54 | 2.09 | 5.24 | 2.28 | 6.26 | 4.87 | 6.26 | 7.07 | +| Transducer Greedy Streaming | 26.83|28.74 | 1.75 | 1.91 | 3.84 | 4.12 | 17.83 | 3.23 | 2.71 | 7.31 | 3.16 | 8.69 | 5.71 | 7.91 | 8.54 | + +Pre-trained model can be found here : https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-large ### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model diff --git a/egs/multi_zh-hans/ASR/whisper/multi_dataset.py b/egs/multi_zh-hans/ASR/whisper/multi_dataset.py deleted file mode 100644 index d0054c4f7..000000000 --- a/egs/multi_zh-hans/ASR/whisper/multi_dataset.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import glob -import logging -import re -from pathlib import Path -from typing import Dict, List - -import lhotse -from lhotse import CutSet, load_manifest_lazy - - -class MultiDataset: - def __init__(self, fbank_dir: str): - """ - Args: - manifest_dir: - It is expected to contain the following files: - - aishell_cuts_train.jsonl.gz - - aishell2_cuts_train.jsonl.gz - - aishell4_cuts_train_L.jsonl.gz - - aishell4_cuts_train_M.jsonl.gz - - aishell4_cuts_train_S.jsonl.gz - - alimeeting-far_cuts_train.jsonl.gz - - magicdata_cuts_train.jsonl.gz - - primewords_cuts_train.jsonl.gz - - stcmds_cuts_train.jsonl.gz - - thchs_30_cuts_train.jsonl.gz - - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz - - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz - - wenetspeech/cuts_L_fixed.jsonl.gz - """ - self.fbank_dir = Path(fbank_dir) - - def train_cuts(self) -> CutSet: - logging.info("About to get multidataset train cuts") - - # THCHS-30 - logging.info("Loading THCHS-30 in lazy mode") - thchs_30_cuts = load_manifest_lazy( - self.fbank_dir / "thchs_30_cuts_train.jsonl.gz" - ) - - # AISHELL-1 - logging.info("Loading Aishell-1 in lazy mode") - aishell_cuts = load_manifest_lazy( - self.fbank_dir / "aishell_cuts_train.jsonl.gz" - ) - - # AISHELL-2 - logging.info("Loading Aishell-2 in lazy mode") - aishell_2_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_train.jsonl.gz" - ) - - # AISHELL-4 - logging.info("Loading Aishell-4 in lazy mode") - aishell_4_L_cuts = load_manifest_lazy( - self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz" - ) - aishell_4_M_cuts = load_manifest_lazy( - self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz" - ) - aishell_4_S_cuts = load_manifest_lazy( - self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz" - ) - - # ST-CMDS - logging.info("Loading ST-CMDS in lazy mode") - stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz") - - # Primewords - logging.info("Loading Primewords in lazy mode") - primewords_cuts = load_manifest_lazy( - self.fbank_dir / "primewords_cuts_train.jsonl.gz" - ) - - # MagicData - logging.info("Loading MagicData in lazy mode") - magicdata_cuts = load_manifest_lazy( - self.fbank_dir / "magicdata_cuts_train.jsonl.gz" - ) - - # Ali-Meeting - logging.info("Loading Ali-Meeting in lazy mode") - alimeeting_cuts = load_manifest_lazy( - self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz" - ) - - # WeNetSpeech - logging.info("Loading WeNetSpeech in lazy mode") - wenetspeech_L_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz" - ) - - # KeSpeech - logging.info("Loading KeSpeech in lazy mode") - kespeech_1_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz" - ) - kespeech_2_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz" - ) - - return CutSet.mux( - thchs_30_cuts, - aishell_cuts, - aishell_2_cuts, - aishell_4_L_cuts, - aishell_4_M_cuts, - aishell_4_S_cuts, - alimeeting_cuts, - stcmds_cuts, - primewords_cuts, - magicdata_cuts, - wenetspeech_L_cuts, - kespeech_1_cuts, - kespeech_2_cuts, - weights=[ - len(thchs_30_cuts), - len(aishell_cuts), - len(aishell_2_cuts), - len(aishell_4_L_cuts), - len(aishell_4_M_cuts), - len(aishell_4_S_cuts), - len(alimeeting_cuts), - len(stcmds_cuts), - len(primewords_cuts), - len(magicdata_cuts), - len(wenetspeech_L_cuts), - len(kespeech_1_cuts), - len(kespeech_2_cuts), - ], - ) - - def dev_cuts(self) -> CutSet: - logging.info("About to get multidataset dev cuts") - - # WeNetSpeech 
- logging.info("Loading WeNetSpeech DEV set in lazy mode") - wenetspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz" - ) - - return wenetspeech_dev_cuts - - def test_cuts(self) -> Dict[str, CutSet]: - logging.info("About to get multidataset test cuts") - - # AISHELL - logging.info("Loading Aishell set in lazy mode") - aishell_test_cuts = load_manifest_lazy( - self.fbank_dir / "aishell_cuts_test.jsonl.gz" - ) - aishell_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell_cuts_dev.jsonl.gz" - ) - - # AISHELL-2 - logging.info("Loading Aishell-2 set in lazy mode") - aishell2_test_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_test.jsonl.gz" - ) - aishell2_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" - ) - - # AISHELL-4 - logging.info("Loading Aishell-4 TEST set in lazy mode") - aishell4_test_cuts = load_manifest_lazy( - self.fbank_dir / "aishell4_cuts_test.jsonl.gz" - ) - - # Ali-Meeting - logging.info("Loading Ali-Meeting set in lazy mode") - alimeeting_test_cuts = load_manifest_lazy( - self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz" - ) - alimeeting_eval_cuts = load_manifest_lazy( - self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" - ) - - # MagicData - logging.info("Loading MagicData set in lazy mode") - magicdata_test_cuts = load_manifest_lazy( - self.fbank_dir / "magicdata_cuts_test.jsonl.gz" - ) - magicdata_dev_cuts = load_manifest_lazy( - self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" - ) - - # KeSpeech - logging.info("Loading KeSpeech set in lazy mode") - kespeech_test_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz" - ) - kespeech_dev_phase1_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" - ) - kespeech_dev_phase2_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" - ) - - # WeNetSpeech - logging.info("Loading WeNetSpeech set in lazy mode") - wenetspeech_test_meeting_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz" - ) - wenetspeech_test_net_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz" - ) - wenetspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz" - ) - - return { - "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, - # "aishell_test": aishell_test_cuts, - # "aishell_dev": aishell_dev_cuts, - # "ali-meeting_test": alimeeting_test_cuts, - # "ali-meeting_eval": alimeeting_eval_cuts, - # "aishell-4_test": aishell4_test_cuts, - # "aishell-2_test": aishell2_test_cuts, - # "aishell-2_dev": aishell2_dev_cuts, - # "magicdata_test": magicdata_test_cuts, - # "magicdata_dev": magicdata_dev_cuts, - # "kespeech-asr_test": kespeech_test_cuts, - # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts, - # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts, - # "wenetspeech-net_test": wenetspeech_test_net_cuts, - # "wenetspeech_dev": wenetspeech_dev_cuts, - } diff --git a/egs/multi_zh-hans/ASR/whisper/multi_dataset.py b/egs/multi_zh-hans/ASR/whisper/multi_dataset.py new file mode 120000 index 000000000..d2e14a1ad --- /dev/null +++ b/egs/multi_zh-hans/ASR/whisper/multi_dataset.py @@ -0,0 +1 @@ +../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py index 5143f945d..8d4a81fb0 100755 
--- a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py @@ -46,7 +46,7 @@ import torch.nn as nn from asr_datamodule import AsrDataModule from lhotse.cut import Cut from multi_dataset import MultiDataset -from train import add_model_arguments, get_model, get_params +from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting from icefall.checkpoint import ( average_checkpoints, @@ -367,21 +367,18 @@ def decode_dataset( hyps_dict = decode_one_batch( params=params, model=model, - HLG=HLG, H=H, bpe_model=bpe_model, batch=batch, - word_table=word_table, - G=G, ) for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = list(ref_text.replace(" ", "")) - hyp_words = list("".join(hyp_words)) - this_batch.append((cut_id, ref_words, hyp_words)) + ref_text = normalize_text_alimeeting(ref_text) + hyp_text = "".join(hyp_words) + this_batch.append((cut_id, ref_text, hyp_text)) results[name].extend(this_batch) @@ -583,7 +580,7 @@ def main(): data_module = AsrDataModule(args) multi_dataset = MultiDataset(args.manifest_dir) - test_sets_cuts = multi_dataset.test_cuts() + test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()} def remove_short_utt(c: Cut): T = ((c.num_frames - 7) // 2 + 1) // 2 diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index f501c3c30..5993f243f 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -118,7 +118,7 @@ from beam_search import ( ) from lhotse.cut import Cut from multi_dataset import MultiDataset -from train import add_model_arguments, get_model, get_params +from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting from icefall.checkpoint import ( average_checkpoints, @@ -532,7 +532,6 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [list(str(text).replace(" ", "")) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( @@ -548,6 +547,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_text = normalize_text_alimeeting(ref_text) hyp_text = "".join(hyp_words) this_batch.append((cut_id, ref_text, hyp_text)) @@ -795,7 +795,7 @@ def main(): ) return T > 0 - test_sets_cuts = multi_dataset.test_cuts() + test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()} test_sets = test_sets_cuts.keys() test_dl = [ diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py deleted file mode 100644 index b1920e62e..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import glob -import logging -import re -from pathlib import Path -from typing import Dict, List - -import lhotse -from lhotse import CutSet, load_manifest_lazy - - -class MultiDataset: - def __init__(self, fbank_dir: str): - """ - Args: - manifest_dir: - It is expected to contain the following files: - - aidatatang_cuts_train.jsonl.gz - - aishell_cuts_train.jsonl.gz - - aishell2_cuts_train.jsonl.gz - - aishell4_cuts_train_L.jsonl.gz - - aishell4_cuts_train_M.jsonl.gz - - aishell4_cuts_train_S.jsonl.gz - - alimeeting-far_cuts_train.jsonl.gz - - magicdata_cuts_train.jsonl.gz - - primewords_cuts_train.jsonl.gz - - stcmds_cuts_train.jsonl.gz - - thchs_30_cuts_train.jsonl.gz - - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz - - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz - - wenetspeech/cuts_L.jsonl.gz - """ - self.fbank_dir = Path(fbank_dir) - - def train_cuts(self) -> CutSet: - logging.info("About to get multidataset train cuts") - - # THCHS-30 - logging.info("Loading THCHS-30 in lazy mode") - thchs_30_cuts = load_manifest_lazy( - self.fbank_dir / "thchs_30_cuts_train.jsonl.gz" - ) - - # AISHELL-1 - logging.info("Loading Aishell-1 in lazy mode") - aishell_cuts = load_manifest_lazy( - self.fbank_dir / "aishell_cuts_train.jsonl.gz" - ) - - # AISHELL-2 - logging.info("Loading Aishell-2 in lazy mode") - aishell_2_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_train.jsonl.gz" - ) - - # AISHELL-4 - logging.info("Loading Aishell-4 in lazy mode") - aishell_4_L_cuts = load_manifest_lazy( - self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz" - ) - aishell_4_M_cuts = load_manifest_lazy( - self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz" - ) - aishell_4_S_cuts = load_manifest_lazy( - self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz" - ) - - # ST-CMDS - logging.info("Loading ST-CMDS in lazy mode") - stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz") - - # Primewords - logging.info("Loading Primewords in lazy mode") - primewords_cuts = load_manifest_lazy( - self.fbank_dir / "primewords_cuts_train.jsonl.gz" - ) - - # MagicData - logging.info("Loading MagicData in lazy mode") - magicdata_cuts = load_manifest_lazy( - self.fbank_dir / "magicdata_cuts_train.jsonl.gz" - ) - - # Aidatatang_200zh - logging.info("Loading Aidatatang_200zh in lazy mode") - aidatatang_200zh_cuts = load_manifest_lazy( - self.fbank_dir / "aidatatang_cuts_train.jsonl.gz" - ) - - # Ali-Meeting - logging.info("Loading Ali-Meeting in lazy mode") - alimeeting_cuts = load_manifest_lazy( - self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz" - ) - - # WeNetSpeech - logging.info("Loading WeNetSpeech in lazy mode") - wenetspeech_L_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz" - ) - - # KeSpeech - logging.info("Loading KeSpeech in lazy mode") - kespeech_1_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz" - ) - kespeech_2_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz" - ) - - return CutSet.mux( - thchs_30_cuts, - aishell_cuts, - aishell_2_cuts, - 
aishell_4_L_cuts, - aishell_4_M_cuts, - aishell_4_S_cuts, - stcmds_cuts, - primewords_cuts, - magicdata_cuts, - aidatatang_200zh_cuts, - alimeeting_cuts, - wenetspeech_L_cuts, - kespeech_1_cuts, - kespeech_2_cuts, - weights=[ - len(thchs_30_cuts), - len(aishell_cuts), - len(aishell_2_cuts), - len(aishell_4_L_cuts), - len(aishell_4_M_cuts), - len(aishell_4_S_cuts), - len(stcmds_cuts), - len(primewords_cuts), - len(magicdata_cuts), - len(aidatatang_200zh_cuts), - len(alimeeting_cuts), - len(wenetspeech_L_cuts), - len(kespeech_1_cuts), - len(kespeech_2_cuts), - ], - ) - - def dev_cuts(self) -> CutSet: - logging.info("About to get multidataset dev cuts") - - # Aidatatang_200zh - logging.info("Loading Aidatatang_200zh DEV set in lazy mode") - aidatatang_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz" - ) - - # AISHELL - logging.info("Loading Aishell DEV set in lazy mode") - aishell_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell_cuts_dev.jsonl.gz" - ) - - # AISHELL-2 - logging.info("Loading Aishell-2 DEV set in lazy mode") - aishell2_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" - ) - - # Ali-Meeting - logging.info("Loading Ali-Meeting DEV set in lazy mode") - alimeeting_dev_cuts = load_manifest_lazy( - self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" - ) - - # MagicData - logging.info("Loading MagicData DEV set in lazy mode") - magicdata_dev_cuts = load_manifest_lazy( - self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" - ) - - # KeSpeech - logging.info("Loading KeSpeech DEV set in lazy mode") - kespeech_dev_phase1_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" - ) - kespeech_dev_phase2_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" - ) - - # WeNetSpeech - logging.info("Loading WeNetSpeech DEV set in lazy mode") - wenetspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" - ) - - return wenetspeech_dev_cuts - # return [ - # aidatatang_dev_cuts, - # aishell_dev_cuts, - # aishell2_dev_cuts, - # alimeeting_dev_cuts, - # magicdata_dev_cuts, - # kespeech_dev_phase1_cuts, - # kespeech_dev_phase2_cuts, - # wenetspeech_dev_cuts, - # ] - - def test_cuts(self) -> Dict[str, CutSet]: - logging.info("About to get multidataset test cuts") - - # Aidatatang_200zh - logging.info("Loading Aidatatang_200zh set in lazy mode") - aidatatang_test_cuts = load_manifest_lazy( - self.fbank_dir / "aidatatang_cuts_test.jsonl.gz" - ) - aidatatang_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz" - ) - - # AISHELL - logging.info("Loading Aishell set in lazy mode") - aishell_test_cuts = load_manifest_lazy( - self.fbank_dir / "aishell_cuts_test.jsonl.gz" - ) - aishell_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell_cuts_dev.jsonl.gz" - ) - - # AISHELL-2 - logging.info("Loading Aishell-2 set in lazy mode") - aishell2_test_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_test.jsonl.gz" - ) - aishell2_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" - ) - - # AISHELL-4 - logging.info("Loading Aishell-4 TEST set in lazy mode") - aishell4_test_cuts = load_manifest_lazy( - self.fbank_dir / "aishell4_cuts_test.jsonl.gz" - ) - - # Ali-Meeting - logging.info("Loading Ali-Meeting set in lazy mode") - alimeeting_test_cuts = load_manifest_lazy( - self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz" - ) - alimeeting_eval_cuts = 
load_manifest_lazy( - self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" - ) - - # MagicData - logging.info("Loading MagicData set in lazy mode") - magicdata_test_cuts = load_manifest_lazy( - self.fbank_dir / "magicdata_cuts_test.jsonl.gz" - ) - magicdata_dev_cuts = load_manifest_lazy( - self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" - ) - - # KeSpeech - logging.info("Loading KeSpeech set in lazy mode") - kespeech_test_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz" - ) - kespeech_dev_phase1_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" - ) - kespeech_dev_phase2_cuts = load_manifest_lazy( - self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" - ) - - # WeNetSpeech - logging.info("Loading WeNetSpeech set in lazy mode") - wenetspeech_test_meeting_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz" - ) - wenetspeech_test_net_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz" - ) - wenetspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" - ) - - return { - "aidatatang_test": aidatatang_test_cuts, - "aidatatang_dev": aidatatang_dev_cuts, - "alimeeting_test": alimeeting_test_cuts, - "alimeeting_eval": alimeeting_eval_cuts, - "aishell_test": aishell_test_cuts, - "aishell_dev": aishell_dev_cuts, - "aishell-2_test": aishell2_test_cuts, - "aishell-2_dev": aishell2_dev_cuts, - "aishell-4": aishell4_test_cuts, - "magicdata_test": magicdata_test_cuts, - "magicdata_dev": magicdata_dev_cuts, - "kespeech-asr_test": kespeech_test_cuts, - "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts, - "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts, - "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, - "wenetspeech-net_test": wenetspeech_test_net_cuts, - "wenetspeech_dev": wenetspeech_dev_cuts, - } diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py new file mode 120000 index 000000000..d2e14a1ad --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py @@ -0,0 +1 @@ +../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 447ca122f..1fc4c35c1 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -539,6 +539,43 @@ def get_params() -> AttributeDict: return params +def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: + """ + Text normalization similar to M2MeT challenge baseline. 
+ See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + if normalize == "none": + return text + elif normalize == "m2met": + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + + def _to_int_tuple(s: str): return tuple(map(int, s.split(","))) @@ -788,6 +825,9 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] + # remove spaces in texts + texts = [normalize_text_alimeeting(text) for text in texts] + y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) diff --git a/egs/speechio/ASR/local/normalize_results.py b/egs/speechio/ASR/local/normalize_results.py index 79d886617..02277e2a8 100755 --- a/egs/speechio/ASR/local/normalize_results.py +++ b/egs/speechio/ASR/local/normalize_results.py @@ -114,7 +114,8 @@ def extract_hyp_ref_wavname(filename): for line in f: if "ref" in line: ref = line.split("ref=")[1].strip() - ref = ref[2:-2] + if ref[0] == "[": + ref = ref[2:-2] list_elements = ref.split("', '") ref = "".join(list_elements) refs.append(ref) From 785f3f0bcf3d65d3f00232e9c084c78d4bff5bc1 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Tue, 9 Jul 2024 20:04:47 +0800 Subject: [PATCH 191/216] Update RESULTS.md, adding results and model links of zipformer-small/medium CTC/AED models (#1683) --- egs/librispeech/ASR/README.md | 2 +- egs/librispeech/ASR/RESULTS.md | 113 ++++++++++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 93fef7a07..8b87ee19b 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer. | `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | | `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | | `zipformer-ctc` | Zipformer | Use auxiliary attention head | -| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head | The latest recipe | +| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head | The latest recipe | # MMI diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 6f00bc14d..66b147764 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -8,6 +8,117 @@ See for more details. #### Non-streaming +##### small-scale model, number of model parameters: 46282107, i.e., 46.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. 
+ +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-decoding | 3.04 | 7.04 | --epoch 50 --avg 30 | +| attention-decoder-rescoring-no-ngram | 2.45 | 6.08 | --epoch 50 --avg 30 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-small \ + --full-libri 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --ctc-loss-scale 0.1 \ + --attention-decoder-loss-scale 0.9 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --base-lr 0.04 \ + --max-duration 1700 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-decoding attention-decoder-rescoring-no-ngram; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 30 \ + --exp-dir zipformer/exp-small \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --attention-decoder-loss-scale 0.9 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --max-duration 100 \ + --causal 0 \ + --num-paths 100 \ + --decoding-method $m +done +``` + +##### medium-scale model, number of model parameters: 89987295, i.e., 90.0 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-decoding | 2.46 | 5.57 | --epoch 50 --avg 22 | +| attention-decoder-rescoring-no-ngram | 2.23 | 4.98 | --epoch 50 --avg 22 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --ctc-loss-scale 0.1 \ + --attention-decoder-loss-scale 0.9 \ + --max-duration 1200 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-decoding attention-decoder-rescoring-no-ngram; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 22 \ + --exp-dir zipformer/exp \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --attention-decoder-loss-scale 0.9 \ + --max-duration 100 \ + --causal 0 \ + --num-paths 100 \ + --decoding-method $m +done +``` + ##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -15,8 +126,6 @@ You can find a pretrained model, training logs, decoding logs, and decoding resu You can use to deploy it. 
-Results of the CTC head: - | decoding method | test-clean | test-other | comment | |--------------------------------------|------------|------------|---------------------| | ctc-decoding | 2.29 | 5.14 | --epoch 50 --avg 29 | From d65187ec5245457a43e352f4c0c9930ab2d98225 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 11 Jul 2024 14:45:35 +0800 Subject: [PATCH 192/216] Small fix (#1686) --- egs/librispeech/ASR/zipformer/scaling.py | 5 +++-- egs/librispeech/ASR/zipformer/train.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e7c3f4ab1..3c7e0fa4e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -636,8 +636,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): ) def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: - """ - Forward function. Args: + """Forward function. + + Args: x: a Tensor of shape (batch_size, channels, seq_len) chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. """ diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3797de484..9b6f4a93a 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -406,7 +406,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -429,7 +429,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -848,7 +848,7 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; + warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
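The warm-up value also controls how the transducer's simple loss is weighted early in training: the recipes take its scale down from 1.0 at the start to the configured `--simple-loss-scale` by `warm_step`, and correspondingly ramp the pruned loss up. A small sketch of that schedule (the shape of the ramp follows the training scripts in this series; treat the exact constants as recipe-dependent):

```python
def warmup_loss_scales(
    batch_idx_train: int, warm_step: int = 2000, simple_loss_scale: float = 0.5
) -> tuple:
    """Return (simple_scale, pruned_scale) for the current step (illustrative constants)."""
    if batch_idx_train >= warm_step:
        return simple_loss_scale, 1.0
    frac = batch_idx_train / warm_step
    simple_scale = 1.0 - frac * (1.0 - simple_loss_scale)  # 1.0 -> simple_loss_scale
    pruned_scale = 0.1 + 0.9 * frac                        # 0.1 -> 1.0
    return simple_scale, pruned_scale


# With warm_step=2000: step 0 -> (1.0, 0.1), step 1000 -> (0.75, 0.55), step 2000 -> (0.5, 1.0)
```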
""" device = model.device if isinstance(model, DDP) else next(model.parameters()).device From 19048e155b5a07b4dd4d1795815a1a0cd4584e25 Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:12:30 +0900 Subject: [PATCH 193/216] Cast grad_scale in whiten to float (#1663) * cast grad_scale in whiten to float * fix cast in zipformer_lora --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- egs/librispeech/ASR/zipformer_lora/scaling.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3c7e0fa4e..164cc7bfd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1033,7 +1033,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1075,7 +1075,7 @@ class Whiten(nn.Module): super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index 3149db9f3..8d7aa8027 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -1137,7 +1137,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1179,7 +1179,7 @@ class Whiten(nn.Module): super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale From f6febd658eb5f1b52771c0b88a5d1205e0d40370 Mon Sep 17 00:00:00 2001 From: Ziwei Li <99643269+NLPvv@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:42:00 +0800 Subject: [PATCH 194/216] "-" replace "_" fix writing error (#1687) --- egs/gigaspeech/KWS/run.sh | 4 ++-- egs/wenetspeech/KWS/run.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh index bd562ce1c..303abd718 100755 --- a/egs/gigaspeech/KWS/run.sh +++ b/egs/gigaspeech/KWS/run.sh @@ -90,7 +90,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 - python ./zipformer/export_onnx_streaming.py \ + python ./zipformer/export-onnx-streaming.py \ --exp-dir zipformer/exp \ --tokens data/lang_bpe_500/tokens.txt \ --epoch 12 \ @@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 - python ./zipformer/export_onnx_streaming.py \ + python ./zipformer/export-onnx-streaming.py \ --exp-dir zipformer/exp_finetune \ --tokens data/lang_bpe_500/tokens.txt \ --epoch 10 \ diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 232ee039a..8472b8531 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -91,7 +91,7 @@ if [ $stage 
-le 2 ] && [ $stop_stage -ge 2 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 - python ./zipformer/export_onnx_streaming.py \ + python ./zipformer/export-onnx-streaming.py \ --exp-dir zipformer/exp \ --tokens data/lang_partial_tone/tokens.txt \ --epoch 18 \ @@ -187,7 +187,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 - python ./zipformer/export_onnx_streaming.py \ + python ./zipformer/export-onnx-streaming.py \ --exp-dir zipformer/exp_finetune \ --tokens data/lang_partial_tone/tokens.txt \ --epoch 10 \ From 334beed2af5212b1b2b8ca112893b120e83d0516 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 12 Jul 2024 16:50:58 +0800 Subject: [PATCH 195/216] fix usages of returned losses after adding attention-decoder in zipformer (#1689) --- egs/aishell/ASR/zipformer/train.py | 3 +- egs/aishell/ASR/zipformer/train_bbpe.py | 3 +- egs/commonvoice/ASR/zipformer/train.py | 3 +- egs/commonvoice/ASR/zipformer/train_char.py | 3 +- egs/gigaspeech/ASR/zipformer/train.py | 3 +- egs/gigaspeech/KWS/zipformer/train.py | 3 +- egs/ksponspeech/ASR/zipformer/train.py | 3 +- egs/libriheavy/ASR/zipformer/train.py | 4 +- egs/librispeech/ASR/zipformer/finetune.py | 3 +- .../ASR/zipformer_adapter/train.py | 3 +- .../ASR/zipformer_lora/finetune.py | 3 +- egs/librispeech/ASR/zipformer_lora/train.py | 3 +- egs/mdcc/ASR/zipformer/train.py | 3 +- egs/multi_zh-hans/ASR/zipformer/train.py | 3 +- egs/multi_zh_en/ASR/zipformer/train.py | 3 +- egs/reazonspeech/ASR/zipformer/train.py | 4 +- egs/spgispeech/ASR/zipformer/train.py | 4 +- egs/wenetspeech/ASR/zipformer/train.py | 3 +- egs/wenetspeech/KWS/zipformer/finetune.py | 44 ++----------------- egs/wenetspeech/KWS/zipformer/train.py | 3 +- 20 files changed, 42 insertions(+), 62 deletions(-) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index a25979226..cd253c597 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -758,7 +758,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -766,6 +766,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py index 0713c5787..46a5506db 100755 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -343,7 +343,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -351,6 +351,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py index 5cda9bfd4..271014db0 100755 --- a/egs/commonvoice/ASR/zipformer/train.py +++ b/egs/commonvoice/ASR/zipformer/train.py @@ -814,7 +814,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( 
x=feature, x_lens=feature_lens, y=y, @@ -822,6 +822,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py index a780bbbbc..0aa7856cc 100755 --- a/egs/commonvoice/ASR/zipformer/train_char.py +++ b/egs/commonvoice/ASR/zipformer/train_char.py @@ -449,7 +449,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -457,6 +457,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index f0ad98147..4c122effe 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -803,7 +803,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -811,6 +811,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index a4d670169..39d8fc6cd 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -806,7 +806,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -814,6 +814,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py index b612b6835..485ea69c9 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -787,7 +787,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -795,6 +795,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py index 8d4d9d067..357e8a827 100644 --- a/egs/libriheavy/ASR/zipformer/train.py +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -55,7 +55,6 @@ It supports training with: import argparse import copy import logging -import random import warnings from pathlib import Path from shutil import copyfile @@ -804,7 +803,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -812,6 +811,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 2f7ec0c17..2ff631914 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -893,7 +893,7 @@ def compute_loss( y = k2.RaggedTensor(y) with 
torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -901,6 +901,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 6c55896a8..3511590da 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -890,7 +890,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -898,6 +898,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index 0464cf65c..3f36f229f 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -903,7 +903,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -911,6 +911,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index 3ccf7d2f1..9ab214e86 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -792,7 +792,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -800,6 +800,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py index 2fae66844..730db7718 100755 --- a/egs/mdcc/ASR/zipformer/train.py +++ b/egs/mdcc/ASR/zipformer/train.py @@ -754,7 +754,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -762,6 +762,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 1fc4c35c1..3dbfc48eb 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -832,7 +832,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -840,6 +840,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py index 5dba584f7..04bb41214 100755 --- a/egs/multi_zh_en/ASR/zipformer/train.py +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -814,7 +814,7 @@ def compute_loss( 
y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -822,6 +822,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/reazonspeech/ASR/zipformer/train.py b/egs/reazonspeech/ASR/zipformer/train.py index 8c6f4bb9a..30bd3efba 100755 --- a/egs/reazonspeech/ASR/zipformer/train.py +++ b/egs/reazonspeech/ASR/zipformer/train.py @@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union import k2 import optim -import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn @@ -791,7 +790,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -799,6 +798,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py index ed66ca29b..dfc21c968 100755 --- a/egs/spgispeech/ASR/zipformer/train.py +++ b/egs/spgispeech/ASR/zipformer/train.py @@ -67,7 +67,6 @@ import torch.nn as nn from asr_datamodule import SPGISpeechAsrDataModule from decoder import Decoder from joiner import Joiner -from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -792,7 +791,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -800,6 +799,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index 3d3762916..25b16f632 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -758,7 +758,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -766,6 +766,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index 3ad16fd11..d19172b38 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -70,8 +70,7 @@ import copy import logging import warnings from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import k2 import optim @@ -80,7 +79,6 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import WenetSpeechAsrDataModule from lhotse.cut import Cut, CutSet -from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor @@ -103,14 +101,13 @@ from train import ( from icefall import diagnostics from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from 
icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon @@ -296,7 +293,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -304,6 +301,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 @@ -344,40 +342,6 @@ def compute_loss( return loss, info -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index eddec7303..40960c2ae 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -815,7 +815,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -823,6 +823,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start From d47c078286319dd0aceed51bcb37c41998fa82ce Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sun, 14 Jul 2024 17:30:13 +0800 Subject: [PATCH 196/216] add decoding method of ctc-greedy-search in zipformer recipe (#1690) --- egs/librispeech/ASR/zipformer/ctc_decode.py | 57 +++++++++++++++------ icefall/decode.py | 31 +++++++++++ 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 85ceb61b8..435a79e7f 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -21,7 +21,16 @@ """ Usage: -(1) ctc-decoding +(1) ctc-greedy-search +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search + +(2) ctc-decoding ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -30,7 +39,7 @@ Usage: --max-duration 600 \ --decoding-method ctc-decoding -(2) 1best +(3) 1best ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ 
-40,7 +49,7 @@ Usage: --hlg-scale 0.6 \ --decoding-method 1best -(3) nbest +(4) nbest ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -50,7 +59,7 @@ Usage: --hlg-scale 0.6 \ --decoding-method nbest -(4) nbest-rescoring +(5) nbest-rescoring ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -62,7 +71,7 @@ Usage: --lm-dir data/lm \ --decoding-method nbest-rescoring -(5) whole-lattice-rescoring +(6) whole-lattice-rescoring ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -74,7 +83,7 @@ Usage: --lm-dir data/lm \ --decoding-method whole-lattice-rescoring -(6) attention-decoder-rescoring-no-ngram +(7) attention-decoder-rescoring-no-ngram ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -84,7 +93,7 @@ Usage: --max-duration 100 \ --decoding-method attention-decoder-rescoring-no-ngram -(7) attention-decoder-rescoring-with-ngram +(8) attention-decoder-rescoring-with-ngram ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -120,6 +129,7 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.decode import ( + ctc_greedy_search, get_lattice, nbest_decoding, nbest_oracle, @@ -220,26 +230,29 @@ def get_parser(): default="ctc-decoding", help="""Decoding method. Supported values are: - - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece model, i.e., lang_dir/bpe.model, to convert word pieces to words. It needs neither a lexicon nor an n-gram LM. - - (2) 1best. Extract the best path from the decoding lattice as the + - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (3) 1best. Extract the best path from the decoding lattice as the decoding result. - - (3) nbest. Extract n paths from the decoding lattice; the path + - (4) nbest. Extract n paths from the decoding lattice; the path with the highest score is the decoding result. - - (4) nbest-rescoring. Extract n paths from the decoding lattice, + - (5) nbest-rescoring. Extract n paths from the decoding lattice, rescore them with an n-gram LM (e.g., a 4-gram LM), the path with the highest score is the decoding result. - - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + - (6) whole-lattice-rescoring. Rescore the decoding lattice with an n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice is the decoding result. you have trained an RNN LM using ./rnn_lm/train.py - - (6) nbest-oracle. Its WER is the lower bound of any n-best + - (7) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. - - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding lattice, rescore them with the attention decoder. - - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM + - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM rescored lattice, rescore them with the attention decoder. """, ) @@ -381,6 +394,15 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) ctc_output = model.ctc_output(encoder_out) # (N, T, C) + if params.decoding_method == "ctc-greedy-search": + hyps = ctc_greedy_search(ctc_output, encoder_out_lens) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] 
+ hyps = bpe_model.decode(hyps) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-greedy-search" + return {key: hyps} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -684,6 +706,7 @@ def main(): params.update(vars(args)) assert params.decoding_method in ( + "ctc-greedy-search", "ctc-decoding", "1best", "nbest", @@ -733,7 +756,9 @@ def main(): params.eos_id = 1 params.sos_id = 1 - if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: + if params.decoding_method in [ + "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram" + ]: HLG = None H = k2.ctc_topo( max_token=max_token_id, diff --git a/icefall/decode.py b/icefall/decode.py index 3abd5648a..b17de0ba7 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1473,3 +1473,34 @@ def rescore_with_rnn_lm( key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}" # noqa ans[key] = best_path return ans + + +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def ctc_greedy_search( + ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor +) -> List[List[int]]: + """CTC greedy search. + + Args: + ctc_output: (batch, seq_len, vocab_size) + encoder_out_lens: (batch,) + Returns: + List[List[int]]: greedy search result + """ + batch = ctc_output.shape[0] + index = ctc_output.argmax(dim=-1) # (batch, seq_len) + hyps = [index[i].tolist()[:encoder_out_lens[i]] for i in range(batch)] + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps From 2e132987174d1b1e5743d2a5a351a54493803b8b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 15 Jul 2024 12:01:47 +0800 Subject: [PATCH 197/216] Refactor ctc greedy search. (#1691) Use torch.unique_consecutive() to avoid reinventing the wheel. --- icefall/decode.py | 25 ++++++++------------ test/test_ctc_greedy_search.py | 42 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 16 deletions(-) create mode 100755 test/test_ctc_greedy_search.py diff --git a/icefall/decode.py b/icefall/decode.py index b17de0ba7..dd3af1e99 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1475,21 +1475,10 @@ def rescore_with_rnn_lm( return ans -def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: - # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py - new_hyp: List[int] = [] - cur = 0 - while cur < len(hyp): - if hyp[cur] != 0: - new_hyp.append(hyp[cur]) - prev = cur - while cur < len(hyp) and hyp[cur] == hyp[prev]: - cur += 1 - return new_hyp - - def ctc_greedy_search( - ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + blank_id: int = 0, ) -> List[List[int]]: """CTC greedy search. 
@@ -1501,6 +1490,10 @@ def ctc_greedy_search( """ batch = ctc_output.shape[0] index = ctc_output.argmax(dim=-1) # (batch, seq_len) - hyps = [index[i].tolist()[:encoder_out_lens[i]] for i in range(batch)] - hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + hyps = [ + torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch) + ] + + hyps = [h[h != blank_id].tolist() for h in hyps] + return hyps diff --git a/test/test_ctc_greedy_search.py b/test/test_ctc_greedy_search.py new file mode 100755 index 000000000..a82b2d8f1 --- /dev/null +++ b/test/test_ctc_greedy_search.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +import torch + +from icefall.decode import ctc_greedy_search + + +def test(): + log_probs = torch.tensor( + [ + [ + [10, 1, 2, 1, 1, 3, 2, 3], + [10, 3, 2, 2, 1, 3, 2, 3], + [1, 10, 2, 2, 1, 3, 2, 3], + [1, 10, 2, 2, 1, 3, 2, 3], + [1, 1, 10, 1, 1, 3, 2, 3], + [10, 1, 1, 1, 1, 3, 2, 3], + [1, 1, 1, 10, 1, 3, 2, 3], + ], + [ + [10, 1, 2, 1, 1, 3, 2, 3], + [10, 3, 2, 2, 1, 3, 2, 3], + [1, 10, 2, 2, 1, 3, 2, 3], + [1, 10, 2, 2, 1, 3, 2, 3], + [1, 1, 10, 1, 1, 3, 2, 3], + [10, 1, 1, 1, 1, 3, 2, 3], + [1, 1, 1, 10, 1, 3, 2, 3], + ], + ], + dtype=torch.float32, + ).log_softmax(dim=-1) + + log_probs_length = torch.tensor([7, 6]) + + hyps = ctc_greedy_search(log_probs, log_probs_length) + + assert hyps[0] == [1, 2, 3], hyps[0] + assert hyps[1] == [1, 2], hyps[1] + + +if __name__ == "__main__": + test() From 11151415f371795d1cead5cc96563e60e269a78a Mon Sep 17 00:00:00 2001 From: zzasdf <68544676+zzasdf@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:47:43 +0800 Subject: [PATCH 198/216] fix error in accum_grad (#1693) --- egs/librispeech/SSL/hubert/finetune.py | 2 +- egs/librispeech/SSL/hubert/finetune_ce.py | 2 +- egs/librispeech/SSL/hubert/pretrain.py | 2 +- egs/librispeech/SSL/hubert/pretrain_ce.py | 2 +- egs/librispeech/SSL/zipformer/finetune.py | 2 +- egs/librispeech/SSL/zipformer/pretrain.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 201847aed..17daa3c9d 100644 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -948,7 +948,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index e69a5a8cd..2723cc770 100644 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -948,7 +948,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py index d9bda8857..f183d90fd 100644 --- a/egs/librispeech/SSL/hubert/pretrain.py +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -774,7 +774,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: 
optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py index 24c0d4d3a..94948695d 100644 --- a/egs/librispeech/SSL/hubert/pretrain_ce.py +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -774,7 +774,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index bbb445320..c907b41c5 100644 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -1245,7 +1245,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py index 5f547e0b8..937fb382e 100644 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -1072,7 +1072,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value From 4af81af5a60ead0e3f71961936d2dbef482d648f Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 18 Jul 2024 21:05:59 +0800 Subject: [PATCH 199/216] Update Zipformer-xl 700M Results on multi-hans-zh (#1694) * add blank penalty * update zipformer-xl results * fix typo --- egs/multi_zh-hans/ASR/RESULTS.md | 61 +++++++++++++++++++++++ egs/multi_zh-hans/ASR/zipformer/decode.py | 17 ++++++- 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index 6f5c93ea9..622218d02 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -43,6 +43,66 @@ Fine-tuned models, training logs, decoding logs, tensorboard and decoding result are available at +### Multi Chinese datasets char-based training results (streaming) on zipformer-xl model + +#### Streaming (with CTC head) + +The training command for extra-large model (num of params : ~700M): + +Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features. 
+ +``` +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 20 \ + --use-fp16 1 \ + --max-duration 1200 \ + --num-workers 8 \ + --use-ctc 1 \ + --exp-dir zipformer/exp-xl \ + --causal 1 \ + --num-encoder-layers 2,3,5,6,5,3 \ + --feedforward-dim 1536,2048,3072,4096,3072,1536 \ + --encoder-dim 512,768,1024,1536,1024,512 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --decoder-dim 768 --joiner-dim 768 \ + --value-head-dim 18 \ + --query-head-dim 48 \ + --num-heads 4,4,4,8,4,4 + +``` + +The decoding command for transducer greedy search: + +``` +./zipformer/decode.py \ + --epoch 999 \ + --avg 1 \ + --causal 1 \ + --use-averaged-model False \ + --chunk_size -1 + --left-context-frames -1 \ + --use-ctc 1 \ + --exp-dir zipformer/exp-xl \ + --max-duration 1200 \ + --num-encoder-layers 2,3,5,6,5,3 \ + --feedforward-dim 1536,2048,3072,4096,3072,1536 \ + --encoder-dim 512,768,1024,1536,1024,512 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --decoder-dim 768 --joiner-dim 768 \ + --value-head-dim 18 \ + --query-head-dim 48 \ + --num-heads 4,4,4,8,4,4 +``` + +Character Error Rates (CERs) listed below are produced by the checkpoint of the 18th epoch using BPE model ( # tokens is 2000, byte fallback enabled). + +| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | +|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| +| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| Transducer Greedy Offline | 21.67 | 23.43 | 1.22 | 1.31 | 3.17 | 3.27 | 14.64 | 2.42 | 1.99 | 5.00 | 2.29 | 5.98 | 5.15 | 5.85 | 6.89 | + +Pre-trained model can be found here : https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-xl ### Multi Chinese datasets char-based training results (streaming) on zipformer large model #### Streaming (with CTC head) @@ -64,6 +124,7 @@ Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech --num-encoder-layers 2,2,4,5,4,2 \ --feedforward-dim 768,1024,1536,2048,1536,768 \ --encoder-dim 256,384,512,768,512,256 \ + --blank-penalty 0.7 \ --encoder-unmasked-dim 192,192,256,320,256,192 ``` diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index 5993f243f..a1d018cd2 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -303,6 +303,17 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). 
+ """, + ) add_model_arguments(parser) return parser @@ -431,6 +442,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -455,6 +467,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, ) elif params.decoding_method == "beam_search": hyp = beam_search( @@ -468,8 +481,9 @@ def decode_one_batch( ) hyps.append(sp.decode(hyp).split()) + key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search_" + key: hyps} elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -657,6 +671,7 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" if params.use_averaged_model: params.suffix += "-use-averaged-model" From 3b257dd5ae79bff99470ec1cbbeaa8fae84f956a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 25 Jul 2024 16:46:24 +0800 Subject: [PATCH 200/216] Add docker images for torch 2.4 (#1704) --- .../scripts/docker/generate_build_matrix.py | 6 +- .github/workflows/build-docker-image.yml | 34 ++++++++- .github/workflows/run-docker-image.yml | 34 ++++++++- docker/torch2.0.0-cuda11.7.dockerfile | 2 +- docker/torch2.1.0-cuda11.8.dockerfile | 2 +- docker/torch2.1.0-cuda12.1.dockerfile | 2 +- docker/torch2.2.0-cuda11.8.dockerfile | 2 +- docker/torch2.2.0-cuda12.1.dockerfile | 2 +- docker/torch2.2.1-cuda11.8.dockerfile | 2 +- docker/torch2.2.1-cuda12.1.dockerfile | 2 +- docker/torch2.2.2-cuda11.8.dockerfile | 2 +- docker/torch2.2.2-cuda12.1.dockerfile | 2 +- docker/torch2.3.1-cuda11.8.dockerfile | 2 +- docker/torch2.3.1-cuda12.1.dockerfile | 2 +- docker/torch2.4.0-cuda11.8.dockerfile | 73 +++++++++++++++++++ docker/torch2.4.0-cuda12.1.dockerfile | 73 +++++++++++++++++++ docker/torch2.4.0-cuda12.4.dockerfile | 73 +++++++++++++++++++ 17 files changed, 301 insertions(+), 14 deletions(-) create mode 100644 docker/torch2.4.0-cuda11.8.dockerfile create mode 100644 docker/torch2.4.0-cuda12.1.dockerfile create mode 100644 docker/torch2.4.0-cuda12.4.dockerfile diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 7f13c59bd..5a763e044 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20240223" kaldifeat_version = "1.25.4.dev20240223" - version = "20240606" + version = "20240725" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] torch_version += ["1.13.0", "1.13.1"] @@ -53,6 +53,7 @@ def get_matrix(): torch_version += ["2.1.0", "2.1.1", "2.1.2"] torch_version += ["2.2.0", "2.2.1", "2.2.2"] torch_version += ["2.3.0", "2.3.1"] + torch_version += ["2.4.0"] matrix = [] for p in python_version: @@ -78,6 +79,9 @@ def get_matrix(): elif t == "2.3.1": k2_version_2 = "1.24.4.dev20240606" kaldifeat_version_2 = "1.25.4.dev20240606" + elif t == "2.4.0": + k2_version_2 = "1.24.4.dev20240725" + kaldifeat_version_2 = "1.25.4.dev20240725" matrix.append( { diff --git a/.github/workflows/build-docker-image.yml 
b/.github/workflows/build-docker-image.yml index 23dcb519f..77480bd3e 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout @@ -37,6 +37,38 @@ jobs: rm -rf /opt/hostedtoolcache df -h + - name: Free more space + shell: bash + run: | + # https://github.com/orgs/community/discussions/25678 + cd /opt + find . -maxdepth 1 -mindepth 1 '!' -path ./containerd '!' -path ./actionarchivecache '!' -path ./runner '!' -path ./runner-cache -exec rm -rf '{}' ';' + + sudo rm -rf /usr/share/dotnet + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + # this might remove tools that are actually needed, + # if set to "true" but frees about 6 GB + tool-cache: false + + # all of these default to true, but feel free to set to + # "false" if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: false + swap-storage: true + + - name: Check space + shell: bash + run: | + df -h + - name: Log in to Docker Hub uses: docker/login-action@v2 with: diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index 336d930ca..05c630ad5 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v2 @@ -28,6 +28,38 @@ jobs: rm -rf /opt/hostedtoolcache df -h + - name: Free more space + shell: bash + run: | + # https://github.com/orgs/community/discussions/25678 + cd /opt + find . -maxdepth 1 -mindepth 1 '!' -path ./containerd '!' -path ./actionarchivecache '!' -path ./runner '!' 
-path ./runner-cache -exec rm -rf '{}' ';' + + sudo rm -rf /usr/share/dotnet + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + # this might remove tools that are actually needed, + # if set to "true" but frees about 6 GB + tool-cache: false + + # all of these default to true, but feel free to set to + # "false" if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: false + swap-storage: true + + - name: Check space + shell: bash + run: | + df -h + - name: Run the build process with Docker uses: addnab/docker-run-action@v3 with: diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index e2e27b55d..22f0a7a95 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index de1e07e69..e87e99468 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index 89303797a..b2628ef9c 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.0-cuda11.8.dockerfile b/docker/torch2.2.0-cuda11.8.dockerfile index 3364477a8..0f65f9595 100644 --- a/docker/torch2.2.0-cuda11.8.dockerfile +++ b/docker/torch2.2.0-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.0-cuda12.1.dockerfile b/docker/torch2.2.0-cuda12.1.dockerfile index 
3cc41902d..7a544c0b2 100644 --- a/docker/torch2.2.0-cuda12.1.dockerfile +++ b/docker/torch2.2.0-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.1-cuda11.8.dockerfile b/docker/torch2.2.1-cuda11.8.dockerfile index 76b785622..0c04314a7 100644 --- a/docker/torch2.2.1-cuda11.8.dockerfile +++ b/docker/torch2.2.1-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.1-cuda12.1.dockerfile b/docker/torch2.2.1-cuda12.1.dockerfile index 55bdfa4d7..5c4c9a99a 100644 --- a/docker/torch2.2.1-cuda12.1.dockerfile +++ b/docker/torch2.2.1-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.2-cuda11.8.dockerfile b/docker/torch2.2.2-cuda11.8.dockerfile index 02de82c50..d712dd57a 100644 --- a/docker/torch2.2.2-cuda11.8.dockerfile +++ b/docker/torch2.2.2-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.2-cuda12.1.dockerfile b/docker/torch2.2.2-cuda12.1.dockerfile index 44ad38b8e..af0e966e7 100644 --- a/docker/torch2.2.2-cuda12.1.dockerfile +++ b/docker/torch2.2.2-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.3.1-cuda11.8.dockerfile b/docker/torch2.3.1-cuda11.8.dockerfile index 545b42e9f..ee07a4c24 100644 --- a/docker/torch2.3.1-cuda11.8.dockerfile +++ b/docker/torch2.3.1-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN 
apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.3.1-cuda12.1.dockerfile b/docker/torch2.3.1-cuda12.1.dockerfile index ca13752e4..f5bac35a2 100644 --- a/docker/torch2.3.1-cuda12.1.dockerfile +++ b/docker/torch2.3.1-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.4.0-cuda11.8.dockerfile b/docker/torch2.4.0-cuda11.8.dockerfile new file mode 100644 index 000000000..a5ffc0bb5 --- /dev/null +++ b/docker/torch2.4.0-cuda11.8.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240725+cuda11.8.torch2.4.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240725+cuda11.8.torch2.4.0" +ARG TORCHAUDIO_VERSION="2.4.0+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.4.0-cuda12.1.dockerfile b/docker/torch2.4.0-cuda12.1.dockerfile new file mode 100644 index 000000000..01208ce2d --- /dev/null +++ b/docker/torch2.4.0-cuda12.1.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240725+cuda12.1.torch2.4.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240725+cuda12.1.torch2.4.0" +ARG TORCHAUDIO_VERSION="2.4.0+cu121" + +LABEL authors="Fangjun Kuang " +LABEL 
k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.4.0-cuda12.4.dockerfile b/docker/torch2.4.0-cuda12.4.dockerfile new file mode 100644 index 000000000..d0d300cfa --- /dev/null +++ b/docker/torch2.4.0-cuda12.4.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240725+cuda12.4.torch2.4.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240725+cuda12.4.torch2.4.0" +ARG TORCHAUDIO_VERSION="2.4.0+cu124" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall From 1730fce688aa4cb6c3162ed860e29c6a72da1604 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Tue, 13 Aug 2024 17:02:14 +0200 Subject: [PATCH 201/216] split `save_results()` -> `save_asr_output()` + `save_wer_results()` (#1712) - the idea is to support `--skip-scoring` argument passed to a decoding script - created for Transducer decoding (non-streaming, streaming) - it can be done 
also for CTC decoding... (not yet) - also added `--label` for extra label in `streaming_decode.py` - and also added `set_caching_enabled(True)`, which has no effect on librispeech, but it leads to faster runtime on DBs with long recordings (assuming `librispeech/zipformer` scripts are the example scripts for other setups) --- egs/librispeech/ASR/zipformer/ctc_decode.py | 96 +++++++++---- egs/librispeech/ASR/zipformer/decode.py | 136 +++++++++++------- .../ASR/zipformer/streaming_decode.py | 88 +++++++++--- 3 files changed, 217 insertions(+), 103 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 435a79e7f..9db429959 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -120,6 +120,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( @@ -296,6 +297,13 @@ def get_parser(): """, ) + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""" + ) + add_model_arguments(parser) return parser @@ -455,7 +463,7 @@ def decode_one_batch( # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] hyps = [s.split() for s in hyps] key = "ctc-decoding" - return {key: hyps} + return {key: hyps} # note: returns words if params.decoding_method == "attention-decoder-rescoring-no-ngram": best_path_dict = rescore_with_attention_decoder_no_ngram( @@ -492,7 +500,7 @@ def decode_one_batch( ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa return {key: hyps} if params.decoding_method in ["1best", "nbest"]: @@ -500,7 +508,7 @@ def decode_one_batch( best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - key = "no_rescore" + key = "no-rescore" else: best_path = nbest_decoding( lattice=lattice, @@ -508,11 +516,11 @@ def decode_one_batch( use_double_scores=params.use_double_scores, nbest_scale=params.nbest_scale, ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} + return {key: hyps} # note: returns BPE tokens assert params.decoding_method in [ "nbest-rescoring", @@ -646,7 +654,27 @@ def decode_dataset( return results -def save_results( +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. 
+ """ + for key, results in results_dict.items(): + + recogs_filename = ( + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + ) + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], @@ -661,32 +689,30 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}_{key}", results, enable_log=enable_log + ) test_set_wers[key] = wer - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -705,6 +731,9 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( "ctc-greedy-search", "ctc-decoding", @@ -719,9 +748,9 @@ def main(): params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -730,11 +759,11 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." 
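As an aside on the `--skip-scoring` flag added above: a hypothetical `ctc_decode.py` run that only saves the recognition output and skips WER computation might look like the following (the epoch/avg values and exp dir are placeholders, not taken from this patch):

```bash
./zipformer/ctc_decode.py \
  --epoch 40 \
  --avg 16 \
  --exp-dir zipformer/exp \
  --decoding-method ctc-decoding \
  --max-duration 600 \
  --skip-scoring True
```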
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -940,12 +969,19 @@ def main(): G=G, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index df2d555a0..cbfb3728e 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -121,6 +121,7 @@ from beam_search import ( modified_beam_search_lm_shallow_fusion, modified_beam_search_LODR, ) +from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params from icefall import ContextGraph, LmScorer, NgramLm @@ -369,6 +370,14 @@ def get_parser(): modified_beam_search_LODR. """, ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + add_model_arguments(parser) return parser @@ -590,21 +599,23 @@ def decode_one_batch( ) hyps.append(sp.decode(hyp).split()) + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" if params.decoding_method == "greedy_search": return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" - return {key: hyps} + return {prefix: hyps} elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" + prefix += f"_beam-size-{params.beam_size}" if params.decoding_method in ( "modified_beam_search_lm_rescore", "modified_beam_search_lm_rescore_LODR", @@ -617,10 +628,11 @@ def decode_one_batch( return ans else: if params.has_contexts: - prefix += f"-context-score-{params.context_score}" + prefix += f"_context-score-{params.context_score}" return {prefix: hyps} else: - return {f"beam_size_{params.beam_size}": hyps} + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} def decode_dataset( @@ -707,46 +719,58 @@ def decode_dataset( return results -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): + """ + Save text produced by ASR. 
+ """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -762,6 +786,9 @@ def main(): params = get_params() params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( "greedy_search", "beam_search", @@ -783,9 +810,9 @@ def main(): params.has_contexts = False if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -794,20 +821,20 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" if params.decoding_method in ( "modified_beam_search", "modified_beam_search_LODR", @@ -815,19 +842,19 @@ def main(): if params.has_contexts: params.suffix += f"-context-score-{params.context_score}" else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" if "LODR" in params.decoding_method: params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" ) if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -1038,12 +1065,19 @@ def main(): ngram_lm_scale=ngram_lm_scale, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 360523b8e..ebcafbf87 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -43,7 +43,7 @@ import torch from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet +from lhotse import CutSet, set_caching_enabled from streaming_beam_search import ( fast_beam_search_one_best, greedy_search, @@ -76,6 +76,13 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--label", + type=str, + default="", + help="""Extra label of the decoding run.""", + ) + parser.add_argument( "--epoch", type=int, @@ -188,6 +195,14 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip 
scoring, but still save the ASR output (for eval sets).""" + ) + + add_model_arguments(parser) return parser @@ -640,46 +655,60 @@ def decode_dataset( return {key: decode_results} -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[List[str], List[str]]]], ): - test_set_wers = dict() + """ + Save text produced by ASR. + """ for key, results in results_dict.items(): - recog_path = ( + recogs_filename = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recogs_filename, texts=results) + logging.info(f"The transcripts are stored in {recogs_filename}") + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_filename, "w") as f: + with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( + + wer_filename = ( params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -694,6 +723,9 @@ def main(): params = get_params() params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: @@ -706,18 +738,21 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." 
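For reference, a hypothetical streaming run using the two new options (all other values below are illustrative): `--label` only tags the output file names via `params.suffix`, and `--skip-scoring True` stores the transcripts without computing WER:

```bash
./zipformer/streaming_decode.py \
  --epoch 30 \
  --avg 9 \
  --exp-dir zipformer/exp \
  --causal 1 \
  --chunk-size 32 \
  --left-context-frames 256 \
  --decoding-method greedy_search \
  --label my-eval-run \
  --skip-scoring True
```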
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" # for fast_beam_search if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" if params.use_averaged_model: params.suffix += "-use-averaged-model" + if params.label: + params.suffix += f"-{params.label}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -845,12 +880,21 @@ def main(): decoding_graph=decoding_graph, ) - save_results( + + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") From 6ac3343ce53fa6685fca0f876f2c6245af4caac5 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 16 Aug 2024 20:13:02 +0800 Subject: [PATCH 202/216] fix path in README.md (#1722) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 31e514606..81cfc03ce 100644 --- a/README.md +++ b/README.md @@ -375,7 +375,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [libricss]: egs/libricss/SURT [libriheavy]: egs/libriheavy/ASR [mgb2]: egs/mgb2/ASR -[peoplespeech]: egs/peoplespeech/ASR +[peoplespeech]: egs/peoples_speech/ASR [spgispeech]: egs/spgispeech/ASR [voxpopuli]: egs/voxpopuli/ASR [xbmu-amdo31]: egs/xbmu-amdo31/ASR From 595297229405fa74ec0dd53e0e7d0ce051802148 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Sat, 17 Aug 2024 13:24:38 +0800 Subject: [PATCH 203/216] Keep the custom fields in libriheavy manifest (#1719) --- egs/libriheavy/ASR/local/prepare_manifest.py | 10 +++++++--- egs/libriheavy/ASR/prepare.sh | 7 ++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py index 42f392cae..d7e184d86 100755 --- a/egs/libriheavy/ASR/local/prepare_manifest.py +++ b/egs/libriheavy/ASR/local/prepare_manifest.py @@ -29,17 +29,21 @@ def simple_cleanup(text: str) -> str: # Assign text of the supervisions and remove unnecessary entries. 
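# An aside on the new KEEP_CUSTOM_FIELDS argument (an observation, not part of
# the patch): bool() of any non-empty string is True, so the literal string
# "False" passed from prepare.sh would not actually disable keeping the custom
# fields. A stricter parse might look like:
#   keep_custom_fields = sys.argv[3].strip().lower() in ("true", "1", "yes")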
def main(): - assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR" + assert ( + len(sys.argv) == 4 + ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS" fname = Path(sys.argv[1]).name oname = Path(sys.argv[2]) / fname + keep_custom_fields = bool(sys.argv[3]) with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: for line in fin: cut = json.loads(line) cut["supervisions"][0]["text"] = simple_cleanup( cut["supervisions"][0]["custom"]["texts"][0] ) - del cut["supervisions"][0]["custom"] - del cut["custom"] + if not keep_custom_fields: + del cut["supervisions"][0]["custom"] + del cut["custom"] fout.write((json.dumps(cut) + "\n").encode()) diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh index b0736c98b..366a1459f 100755 --- a/egs/libriheavy/ASR/prepare.sh +++ b/egs/libriheavy/ASR/prepare.sh @@ -29,6 +29,11 @@ export CUDA_VISIBLE_DEVICES="" # - speech dl_dir=$PWD/download +# If you want to do PromptASR experiments, please set it to True +# as this will keep the texts and pre_text information required for +# the training of PromptASR. +keep_custom_fields=False + . shared/parse_options.sh || exit 1 # vocab size for sentence piece models. @@ -134,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then for subset in small medium large dev test_clean test_other; do if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then log "Prepare manifest for subset : ${subset}" - ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir + ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir $keep_custom_fields fi done fi From 3fc06cc2b9120a79a3e061bf35cef8d7220a42f3 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:27:25 +0800 Subject: [PATCH 204/216] Support AudioSet training with weighted sampler (#1727) --- egs/audioset/AT/RESULTS.md | 36 +++++-- egs/audioset/AT/local/compute_weight.py | 73 ++++++++++++++ egs/audioset/AT/prepare.sh | 13 ++- egs/audioset/AT/zipformer/at_datamodule.py | 107 ++++++++++++++++----- egs/audioset/AT/zipformer/train.py | 11 ++- 5 files changed, 207 insertions(+), 33 deletions(-) create mode 100644 egs/audioset/AT/local/compute_weight.py diff --git a/egs/audioset/AT/RESULTS.md b/egs/audioset/AT/RESULTS.md index 0128b7018..36613db03 100644 --- a/egs/audioset/AT/RESULTS.md +++ b/egs/audioset/AT/RESULTS.md @@ -35,16 +35,40 @@ python zipformer/train.py \ --master-port 13455 ``` +We recommend that you train the model with weighted sampler, as the model converges +faster with better performance: + +| Model | mAP | +| ------ | ------- | +| Zipformer-AT, train with weighted sampler | 46.6 | + The evaluation command is: ```bash -python zipformer/evaluate.py \ - --epoch 32 \ - --avg 8 \ - --exp-dir zipformer/exp_at_as_full \ - --max-duration 500 +export CUDA_VISIBLE_DEVICES="4,5,6,7" +subset=full +weighted_sampler=1 +bucket_sampler=0 +lr_epochs=15 + +python zipformer/train.py \ + --world-size 4 \ + --audioset-subset $subset \ + --num-epochs 120 \ + --start-epoch 1 \ + --use-fp16 1 \ + --num-events 527 \ + --lr-epochs $lr_epochs \ + --exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \ + --weighted-sampler $weighted_sampler \ + --bucketing-sampler $bucket_sampler \ + --max-duration 1000 \ + --enable-musan True \ + --master-port 13452 ``` +The command for evaluation is the same. 
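For example (the epoch and avg values are placeholders; the exp dir follows the naming used in the training command above):

```bash
python zipformer/evaluate.py \
  --epoch 120 \
  --avg 10 \
  --exp-dir zipformer/exp_AS_full_weighted_sampler1 \
  --max-duration 500
```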
The pre-trained model can be downloaded from https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler + #### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M @@ -92,4 +116,4 @@ python zipformer/evaluate.py \ --encoder-unmasked-dim 192,192,192,192,192,192 \ --exp-dir zipformer/exp_small_at_as_full \ --max-duration 500 -``` \ No newline at end of file +``` diff --git a/egs/audioset/AT/local/compute_weight.py b/egs/audioset/AT/local/compute_weight.py new file mode 100644 index 000000000..a0deddc0c --- /dev/null +++ b/egs/audioset/AT/local/compute_weight.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file generates the manifest and computes the fbank features for AudioSet +dataset. The generated manifests and features are stored in data/fbank. +""" + +import argparse + +import lhotse +from lhotse import load_manifest + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz" + ) + + parser.add_argument( + "--output", + type=str, + required=True, + ) + return parser + + +def main(): + # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py + parser = get_parser() + args = parser.parse_args() + + cuts = load_manifest(args.input_manifest) + + print(f"A total of {len(cuts)} cuts.") + + label_count = [0] * 527 # a total of 527 classes + for c in cuts: + audio_event = c.supervisions[0].audio_event + labels = list(map(int, audio_event.split(";"))) + for label in labels: + label_count[label] += 1 + + with open(args.output, "w") as f: + for c in cuts: + audio_event = c.supervisions[0].audio_event + labels = list(map(int, audio_event.split(";"))) + weight = 0 + for label in labels: + weight += 1000 / (label_count[label] + 0.01) + f.write(f"{c.id} {weight}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/audioset/AT/prepare.sh b/egs/audioset/AT/prepare.sh index f7f73a008..8beaf2d86 100755 --- a/egs/audioset/AT/prepare.sh +++ b/egs/audioset/AT/prepare.sh @@ -10,6 +10,7 @@ stage=-1 stop_stage=4 dl_dir=$PWD/download +fbank_dir=data/fbank # we assume that you have your downloaded the AudioSet and placed # it under $dl_dir/audioset, the folder structure should look like @@ -49,7 +50,6 @@ fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set" - fbank_dir=data/fbank if [! 
-e $fbank_dir/.balanced.done]; then python local/generate_audioset_manifest.py \ --dataset-dir $dl_dir/audioset \ @@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then touch data/fbank/.musan.done fi fi + +# The following stages are required to do weighted-sampling training +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare for weighted-sampling training" + if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then + lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz + fi + python ./local/compute_weight.py \ + --input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \ + --output $fbank_dir/sampling_weights_full.txt +fi diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py index ac8671fa6..b7df01539 100644 --- a/egs/audioset/AT/zipformer/at_datamodule.py +++ b/egs/audioset/AT/zipformer/at_datamodule.py @@ -31,6 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures PrecomputedFeatures, SimpleCutSampler, SpecAugment, + WeightedSimpleCutSampler, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -99,6 +100,20 @@ class AudioSetATDatamodule: help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) + group.add_argument( + "--weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "It cannot be used together with bucketing sampler", + ) + group.add_argument( + "--num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler", + ) group.add_argument( "--bucketing-sampler", type=str2bool, @@ -295,6 +310,9 @@ class AudioSetATDatamodule: ) if self.args.bucketing_sampler: + assert ( + not self.args.weighted_sampler + ), "weighted sampling is not supported in bucket sampler" logging.info("Using DynamicBucketingSampler.") train_sampler = DynamicBucketingSampler( cuts_train, @@ -304,13 +322,26 @@ class AudioSetATDatamodule: drop_last=self.args.drop_last, ) else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - drop_last=self.args.drop_last, - ) + if self.args.weighted_sampler: + # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset" + logging.info("Using weighted SimpleCutSampler") + weights = self.audioset_sampling_weights() + train_sampler = WeightedSimpleCutSampler( + cuts_train, + weights, + num_samples=self.args.num_samples, + max_duration=self.args.max_duration, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + ) logging.info("About to create train dataloader") if sampler_state_dict is not None: @@ -373,11 +404,9 @@ class AudioSetATDatamodule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = AudioTaggingDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)() - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) 
+ if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( @@ -397,21 +426,30 @@ class AudioSetATDatamodule: @lru_cache() def audioset_train_cuts(self) -> CutSet: logging.info("About to get the audioset training cuts.") - balanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" - ) - if self.args.audioset_subset == "full": - unbalanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" - ) - cuts = CutSet.mux( - balanced_cuts, - unbalanced_cuts, - weights=[20000, 2000000], - stop_early=True, + if not self.args.weighted_sampler: + balanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" ) + if self.args.audioset_subset == "full": + unbalanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" + ) + cuts = CutSet.mux( + balanced_cuts, + unbalanced_cuts, + weights=[20000, 2000000], + stop_early=True, + ) + else: + cuts = balanced_cuts else: - cuts = balanced_cuts + # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet" + cuts = load_manifest( + self.args.manifest_dir + / f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz" + ) + logging.info(f"Get {len(cuts)} cuts in total.") + return cuts @lru_cache() @@ -420,3 +458,22 @@ class AudioSetATDatamodule: return load_manifest_lazy( self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz" ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 2d193030a..67c703364 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -789,12 +789,14 @@ def train_one_epoch( rank=0, ) + num_samples = 0 for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) params.batch_idx_train += 1 batch_size = batch["inputs"].size(0) + num_samples += batch_size try: with torch.cuda.amp.autocast(enabled=params.use_fp16): @@ -919,6 +921,12 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) + if num_samples > params.num_samples: + logging.info( + f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch" + ) + break + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -1032,7 +1040,8 @@ def run(rank, world_size, args): return True - train_cuts = train_cuts.filter(remove_short_and_long_utt) + if not params.weighted_sampler: + train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint From 3b434fe83c40eaf3c4739c26d11aa3a3b8af8ddc Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Fri, 23 Aug 2024 09:33:46 +0800 Subject: [PATCH 205/216] fix triton onnx export (#1730) --- egs/librispeech/ASR/zipformer/export-onnx.py | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index ed8a0ef0f..ca3cbf0d5 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module): - encoder_out_lens, A 1-D tensor of shape (N,) """ x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) x = x.permute(1, 0, 2) encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) From a6c02a4d8c5c5db3f30899ca622a813640aba63f Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Fri, 23 Aug 2024 09:42:22 +0800 Subject: [PATCH 206/216] zipformer BF16 training recipe (#1700) Support Zipformer AMP +BF16 training --- egs/librispeech/ASR/RESULTS.md | 17 +++++++++ egs/librispeech/ASR/zipformer/scaling.py | 12 +++--- egs/librispeech/ASR/zipformer/train.py | 47 ++++++++++++++++++------ 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 66b147764..bc7d8a5ef 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -307,6 +307,23 @@ done To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html). +We also support training Zipformer with AMP+bf16 format (requires bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.** + +The amp+bf16 training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 0 \ + --use-bf16 1 \ + --exp-dir zipformer/exp_amp_bf16 \ + --causal 0 \ + --full-libri 1 \ + --max-duration 1000 +``` + ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M The tensorboard log can be found at diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 164cc7bfd..2a40b8d64 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function): # (presumably) that op does not support float16, and autocast # is enabled. 
if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) + ans = ans.to(torch.get_autocast_gpu_dtype()) ctx.save_for_backward(ans) ctx.x_dtype = x.dtype ctx.dim = dim @@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) s = torch.sigmoid(x - 1.0) @@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) @@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function): d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) + y = y.to(torch.get_autocast_gpu_dtype()) return y @staticmethod @@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function): def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) @@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function): d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) + y = y.to(torch.get_autocast_gpu_dtype()) return y @staticmethod diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9b6f4a93a..9c1c7f5a7 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -521,6 +521,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + add_model_arguments(parser) return parser @@ -1027,7 +1034,9 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,9 +1056,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except Exception as e: - logging.info( - f"Caught exception: {e}." - ) + logging.info(f"Caught exception: {e}.") save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise @@ -1090,7 +1097,7 @@ def train_one_epoch( rank=rank, ) - if batch_idx % 100 == 0 and params.use_fp16: + if batch_idx % 100 == 0 and params.use_autocast: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
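A minimal, self-contained sketch of the AMP dtype selection this patch introduces (assuming a CUDA device; this is not the actual icefall training loop): bf16 requires hardware support and, unlike fp16, keeps the fp32 exponent range, so gradient scaling is less critical, although the recipe still routes both through `GradScaler`.

```python
import torch

use_fp16, use_bf16 = False, True  # mirrors --use-fp16 / --use-bf16

if use_bf16:  # amp + bf16
    assert torch.cuda.is_bf16_supported(), "This GPU does not support bf16!"
    dtype, use_autocast = torch.bfloat16, True
elif use_fp16:  # amp + fp16
    dtype, use_autocast = torch.float16, True
else:  # fp32
    dtype, use_autocast = torch.float32, False

model = torch.nn.Linear(80, 500).cuda()
x = torch.randn(8, 80, device="cuda")
scaler = torch.cuda.amp.GradScaler(enabled=use_autocast, init_scale=1.0)

with torch.cuda.amp.autocast(enabled=use_autocast, dtype=dtype):
    loss = model(x).sum()

scaler.scale(loss).backward()  # scale() is a pass-through when autocast is off
```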
@@ -1109,14 +1116,14 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") ) if tb_writer is not None: @@ -1128,7 +1135,7 @@ def train_one_epoch( tb_writer, "train/current_", params.batch_idx_train ) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: + if params.use_autocast: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) @@ -1204,9 +1211,25 @@ def run(rank, world_size, args): params.ctc_loss_scale = 1.0 else: assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( - params.ctc_loss_scale, params.attention_decoder_loss_scale + params.ctc_loss_scale, + params.attention_decoder_loss_scale, ) + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + logging.info(params) logging.info("About to create model") @@ -1339,7 +1362,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): loss, _ = compute_loss( params=params, model=model, From cea0dbe7b1cd4d5b7512c7974e53034ef456dd70 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 28 Aug 2024 12:15:01 +0800 Subject: [PATCH 207/216] fix gigaspeech_prepare.sh (#1734) --- egs/gigaspeech/ASR/prepare.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index 5e54b669a..219197e13 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -161,14 +161,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Split XL subset into pieces (may take 30 minutes)" split_dir=data/fbank/XL_split if [ ! 
-f $split_dir/.split_completed ]; then - lhotse split-lazy ./data/fbank/cuts_XL_raw.jsonl.gz $split_dir $num_per_split + lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $num_per_split touch $split_dir/.split_completed fi fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Compute features for XL" - num_splits=$(find data/fbank/XL_split -name "cuts_XL_raw.*.jsonl.gz" | wc -l) + num_splits=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL_raw.*.jsonl.gz" | wc -l) python3 ./local/compute_fbank_gigaspeech_splits.py \ --num-workers 20 \ --batch-duration 600 \ @@ -177,9 +177,9 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Combine features for XL (may take 3 hours)" - if [ ! -f data/fbank/cuts_XL.jsonl.gz ]; then - pieces=$(find data/fbank/XL_split -name "cuts_XL.*.jsonl.gz") - lhotse combine $pieces data/fbank/cuts_XL.jsonl.gz + if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then + pieces=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL.*.jsonl.gz") + lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz fi fi From f233ffa02ae248e4ad2c526d5c35c4a9ade601f5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 7 Sep 2024 18:17:04 +0800 Subject: [PATCH 208/216] Add docker images for torch 2.4.1 (#1743) --- .../scripts/docker/generate_build_matrix.py | 6 +- .github/workflows/build-docker-image.yml | 2 +- docker/torch2.4.1-cuda11.8.dockerfile | 73 +++++++++++++++++++ docker/torch2.4.1-cuda12.1.dockerfile | 73 +++++++++++++++++++ docker/torch2.4.1-cuda12.4.dockerfile | 73 +++++++++++++++++++ docs/source/docker/intro.rst | 6 ++ 6 files changed, 231 insertions(+), 2 deletions(-) create mode 100644 docker/torch2.4.1-cuda11.8.dockerfile create mode 100644 docker/torch2.4.1-cuda12.1.dockerfile create mode 100644 docker/torch2.4.1-cuda12.4.dockerfile diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 5a763e044..492d3ed47 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20240223" kaldifeat_version = "1.25.4.dev20240223" - version = "20240725" + version = "20240905" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] torch_version += ["1.13.0", "1.13.1"] @@ -54,6 +54,7 @@ def get_matrix(): torch_version += ["2.2.0", "2.2.1", "2.2.2"] torch_version += ["2.3.0", "2.3.1"] torch_version += ["2.4.0"] + torch_version += ["2.4.1"] matrix = [] for p in python_version: @@ -82,6 +83,9 @@ def get_matrix(): elif t == "2.4.0": k2_version_2 = "1.24.4.dev20240725" kaldifeat_version_2 = "1.25.4.dev20240725" + elif t == "2.4.1": + k2_version_2 = "1.24.4.dev20240905" + kaldifeat_version_2 = "1.25.4.dev20240905" matrix.append( { diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index 77480bd3e..a473590a3 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", 
"torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.4.1-cuda12.4", "torch2.4.1-cuda12.1", "torch2.4.1-cuda11.8", "torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout diff --git a/docker/torch2.4.1-cuda11.8.dockerfile b/docker/torch2.4.1-cuda11.8.dockerfile new file mode 100644 index 000000000..bc1782b0d --- /dev/null +++ b/docker/torch2.4.1-cuda11.8.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.1-cuda11.8-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240905+cuda11.8.torch2.4.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda11.8.torch2.4.1" +ARG TORCHAUDIO_VERSION="2.4.1+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.4.1-cuda12.1.dockerfile b/docker/torch2.4.1-cuda12.1.dockerfile new file mode 100644 index 000000000..df2ea61a4 --- /dev/null +++ b/docker/torch2.4.1-cuda12.1.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.1-cuda12.1-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240905+cuda12.1.torch2.4.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda12.1.torch2.4.1" +ARG TORCHAUDIO_VERSION="2.4.1+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f 
https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.4.1-cuda12.4.dockerfile b/docker/torch2.4.1-cuda12.4.dockerfile new file mode 100644 index 000000000..4d6da2804 --- /dev/null +++ b/docker/torch2.4.1-cuda12.4.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.1-cuda12.4-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240905+cuda12.4.torch2.4.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda12.4.torch2.4.1" +ARG TORCHAUDIO_VERSION="2.4.1+cu124" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index f3d2b0727..5fc3fa4d5 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -34,6 +34,12 @@ which will give you something like below: .. 
code-block:: bash + "torch2.4.1-cuda12.4" + "torch2.4.1-cuda12.1" + "torch2.4.1-cuda11.8" + "torch2.4.0-cuda12.4" + "torch2.4.0-cuda12.1" + "torch2.4.0-cuda11.8" "torch2.3.1-cuda12.1" "torch2.3.1-cuda11.8" "torch2.2.2-cuda12.1" From d4b43236999da5314e889544524782fecafe8ddc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 7 Sep 2024 19:21:26 +0800 Subject: [PATCH 209/216] Fix github actions CI tests (#1744) --- .github/scripts/docker/generate_build_matrix.py | 9 +++++---- ...gigaspeech-pruned-transducer-stateless2-2022-05-12.sh | 4 ++-- .github/scripts/test-onnx-export.sh | 1 + .github/workflows/build-doc.yml | 2 ++ .github/workflows/run-gigaspeech-2022-05-13.yml | 4 +++- .../workflows/run-gigaspeech-zipformer-2023-10-17.yml | 4 ++-- ...librispeech-lstm-transducer-stateless2-2022-09-03.yml | 4 +++- .github/workflows/run-multi-corpora-zipformer.yml | 2 ++ .github/workflows/run-ptb-rnn-lm.yml | 4 +++- .github/workflows/run-swbd-conformer-ctc.yml | 2 ++ .../run-wenetspeech-pruned-transducer-stateless2.yml | 2 ++ .github/workflows/style_check.yml | 2 ++ .github/workflows/test-ncnn-export.yml | 2 ++ .github/workflows/test-onnx-export.yml | 2 ++ .github/workflows/test.yml | 2 +- 15 files changed, 34 insertions(+), 12 deletions(-) diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 492d3ed47..08281151e 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -48,10 +48,11 @@ def get_matrix(): version = "20240905" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] - torch_version += ["1.13.0", "1.13.1"] - torch_version += ["2.0.0", "2.0.1"] - torch_version += ["2.1.0", "2.1.1", "2.1.2"] - torch_version += ["2.2.0", "2.2.1", "2.2.2"] + # torch_version += ["1.13.0", "1.13.1"] + # torch_version += ["2.0.0", "2.0.1"] + # torch_version += ["2.1.0", "2.1.1", "2.1.2"] + # torch_version += ["2.2.0", "2.2.1", "2.2.2"] + # Test only torch >= 2.3.0 torch_version += ["2.3.0", "2.3.1"] torch_version += ["2.4.0"] torch_version += ["2.4.1"] diff --git a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh index b61a9d7b6..c9e798a68 100755 --- a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh +++ b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh @@ -29,8 +29,8 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == ls -lh data/fbank ls -lh pruned_transducer_stateless2/exp - ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz - ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz + ln -sf data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz + ln -sf data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz log "Decoding dev and test" diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index fcfc11fa6..3252c37f1 100755 --- a/.github/scripts/test-onnx-export.sh +++ b/.github/scripts/test-onnx-export.sh @@ -25,6 +25,7 @@ popd log "Export via torch.jit.script()" ./zipformer/export.py \ + --use-averaged-model 0 \ --exp-dir $repo/exp \ --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index c622476f2..ca96e6de5 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ 
-26,6 +26,8 @@ on: pull_request: types: [labeled] + workflow_dispatch: + concurrency: group: build_doc-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index 3121520c1..2c1d44fbf 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -33,6 +33,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: run_gigaspeech_2022_05_13-${{ github.ref }} cancel-in-progress: true @@ -119,7 +121,7 @@ jobs: find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 - name: Upload decoding results for gigaspeech pruned_transducer_stateless2 - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12 diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml index 87090e310..4ecc2aea0 100644 --- a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml +++ b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml @@ -42,7 +42,7 @@ concurrency: jobs: run_gigaspeech_2023_10_17_zipformer: - if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' runs-on: ${{ matrix.os }} strategy: matrix: @@ -133,7 +133,7 @@ jobs: find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for gigaspeech zipformer - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 501fae38c..6a3f4eb40 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -16,6 +16,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: run_librispeech_lstm_transducer_stateless2_2022_09_03-${{ github.ref }} cancel-in-progress: true @@ -156,7 +158,7 @@ jobs: find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for lstm_transducer_stateless2 - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-lstm_transducer_stateless2-2022-09-03 diff --git a/.github/workflows/run-multi-corpora-zipformer.yml 
b/.github/workflows/run-multi-corpora-zipformer.yml index 38f7eb908..84f9f3a0d 100644 --- a/.github/workflows/run-multi-corpora-zipformer.yml +++ b/.github/workflows/run-multi-corpora-zipformer.yml @@ -23,6 +23,8 @@ on: pull_request: types: [labeled] + workflow_dispatch: + concurrency: group: run_multi-corpora_zipformer-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml index f8d9c02c5..6e4077cf4 100644 --- a/.github/workflows/run-ptb-rnn-lm.yml +++ b/.github/workflows/run-ptb-rnn-lm.yml @@ -16,6 +16,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: run_ptb_rnn_lm_training-${{ github.ref }} cancel-in-progress: true @@ -64,7 +66,7 @@ jobs: ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2 - name: Upload pretrained models - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule' with: name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb diff --git a/.github/workflows/run-swbd-conformer-ctc.yml b/.github/workflows/run-swbd-conformer-ctc.yml index 842691d38..b0178bedd 100644 --- a/.github/workflows/run-swbd-conformer-ctc.yml +++ b/.github/workflows/run-swbd-conformer-ctc.yml @@ -23,6 +23,8 @@ on: pull_request: types: [labeled] + workflow_dispatch: + concurrency: group: run-swbd-conformer_ctc-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml index 319a5558a..e76497ec3 100644 --- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml +++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml @@ -23,6 +23,8 @@ on: pull_request: types: [labeled] + workflow_dispatch: + concurrency: group: run_wenetspeech_pruned_transducer_stateless2-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 1c37f13ed..0681ece60 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -24,6 +24,8 @@ on: branches: - master + workflow_dispatch: + concurrency: group: style_check-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/test-ncnn-export.yml b/.github/workflows/test-ncnn-export.yml index 5709f8ebb..ec419d65f 100644 --- a/.github/workflows/test-ncnn-export.yml +++ b/.github/workflows/test-ncnn-export.yml @@ -16,6 +16,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: test_ncnn_export-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml index c05cde3ba..646ca0569 100644 --- a/.github/workflows/test-onnx-export.yml +++ b/.github/workflows/test-onnx-export.yml @@ -16,6 +16,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: test_onnx_export-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 659681b37..9eb7e403c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -105,7 +105,7 @@ jobs: cd ../zipformer pytest -v -s - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: 
path: egs/librispeech/ASR/zipformer/swoosh.pdf name: swoosh.pdf From 559c8a716039bc1f3da2a4d1487292830fd21f06 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Sep 2024 17:10:17 +0800 Subject: [PATCH 210/216] fixed a typo in `prepare.sh` for alimeeting recipes (#1747) --- egs/alimeeting/ASR/prepare.sh | 2 +- egs/alimeeting/ASR_v2/prepare.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 996a1da2d..55f9f019b 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -87,7 +87,7 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare musan manifest" # We assume that you have downloaded the musan corpus - # to data/musan + # to $dl_dir/musan if [ ! -f data/manifests/.musan_manifests.done ]; then log "It may take 6 minutes" mkdir -p data/manifests diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh index 15c20692d..1881cd75c 100755 --- a/egs/alimeeting/ASR_v2/prepare.sh +++ b/egs/alimeeting/ASR_v2/prepare.sh @@ -65,7 +65,7 @@ fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Prepare musan manifest" # We assume that you have downloaded the musan corpus - # to data/musan + # to $dl_dir/musan mkdir -p data/manifests lhotse prepare musan $dl_dir/musan data/manifests fi From 2ff0bb6a884c8f5aafa48551fba8c7d0eeb15b96 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 8 Sep 2024 17:42:55 +0800 Subject: [PATCH 211/216] fix CI tests (#1748) --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9eb7e403c..c22f2edb5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -108,4 +108,4 @@ jobs: - uses: actions/upload-artifact@v4 with: path: egs/librispeech/ASR/zipformer/swoosh.pdf - name: swoosh.pdf + name: swoosh-${{ matrix.python-version }}-${{ matrix.torch-version }} From 65b8a6c730568ed12fccccb244e013f6ae3d7745 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Sep 2024 20:34:49 +0800 Subject: [PATCH 212/216] fixed wrong default value for the `alimeeting` recipe (#1750) --- .../pruned_transducer_stateless7/asr_datamodule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py index 6b56c8a6a..9da820315 100644 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py @@ -82,7 +82,7 @@ class AlimeetingAsrDataModule: group.add_argument( "--manifest-dir", type=Path, - default=Path("data/manifests"), + default=Path("data/fbank"), help="Path to directory with train/valid/test cuts.", ) group.add_argument( @@ -327,9 +327,11 @@ class AlimeetingAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures() + ), return_cuts=True, ) sampler = DynamicBucketingSampler( From a394bf74742c0242f35a514e016df74d6ba42505 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Sep 2024 20:35:07 +0800 Subject: [PATCH 
213/216] fixed gss scripts for `alimeeting` and `ami` recipes (#1749) --- egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh | 4 ++-- egs/ami/ASR/local/prepare_ami_gss.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh index 76db19832..bd25bc9e5 100755 --- a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh +++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh @@ -58,7 +58,7 @@ if [ $stage -le 4 ]; then # for train, we use smaller context and larger batches to speed-up processing for JOB in $(seq $nj); do gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ - $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \ + $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \ --bss-iterations 10 \ --context-duration 5.0 \ --use-garbage-class \ @@ -77,7 +77,7 @@ if [ $stage -le 5 ]; then for part in eval test; do for JOB in $(seq $nj); do gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ - $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \ $EXP_DIR/enhanced \ --bss-iterations 10 \ --context-duration 15.0 \ diff --git a/egs/ami/ASR/local/prepare_ami_gss.sh b/egs/ami/ASR/local/prepare_ami_gss.sh index d5422458b..414c22b12 100755 --- a/egs/ami/ASR/local/prepare_ami_gss.sh +++ b/egs/ami/ASR/local/prepare_ami_gss.sh @@ -58,7 +58,7 @@ if [ $stage -le 4 ]; then # for train, we use smaller context and larger batches to speed-up processing for JOB in $(seq $nj); do gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ - $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \ + $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \ --bss-iterations 10 \ --context-duration 5.0 \ --use-garbage-class \ @@ -77,7 +77,7 @@ if [ $stage -le 5 ]; then for part in dev test; do for JOB in $(seq $nj); do gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ - $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \ $EXP_DIR/enhanced \ --bss-iterations 10 \ --context-duration 15.0 \ From 329e34ac204bfedf7d4169ca4ccd295de7ff8aac Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 10 Sep 2024 19:29:19 +0800 Subject: [PATCH 214/216] Test export onnx models for multi-zh-hans (#1752) --- .github/scripts/multi-zh-hans.sh | 42 ++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh index 427d8887b..e254419ff 100755 --- a/.github/scripts/multi-zh-hans.sh +++ b/.github/scripts/multi-zh-hans.sh @@ -16,6 +16,48 @@ log "pwd: $PWD" cd egs/multi_zh-hans/ASR +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2 +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo +cd exp +git lfs pull --include pretrained.pt +ln -s pretrained.pt epoch-99.pt +cd ../data/lang_bpe_2000 +ls -lh +git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model +git lfs pull --include "*.model" +ls -lh +popd + +log "--------------------------------------------" +log "Export non-streaming ONNX transducer models " +log 
"--------------------------------------------" +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +ls -lh $repo/exp + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav \ + $repo/test_wavs/TEST_MEETING_T0000000113.wav \ + $repo/test_wavs/TEST_MEETING_T0000000219.wav \ + $repo/test_wavs/TEST_MEETING_T0000000351.wav + +rm -rf $repo + repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 log "Downloading pre-trained model from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url From 6f1abd832dc290b62adfdd0f615010c2f3c274a5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 11 Sep 2024 21:04:52 +0800 Subject: [PATCH 215/216] Fix exporting streaming zipformer models. (#1755) --- .../ASR/zipformer/export-onnx-streaming.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 41 +++++++++++++++---- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index e5ceb3683..88c58f581 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -74,7 +74,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from onnxconverter_common import float16 from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -756,6 +755,7 @@ def main(): logging.info(f"Exported joiner to {joiner_filename}") if(params.fp16) : + from onnxconverter_common import float16 logging.info("Generate fp16 models") encoder = onnx.load(encoder_filename) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 69059287b..2a0ae0129 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -191,6 +191,7 @@ class Zipformer2(EncoderInterface): dim=encoder_dim[i], downsample=downsampling_factor[i], dropout=dropout, + causal=causal, ) encoders.append(encoder) @@ -198,7 +199,10 @@ class Zipformer2(EncoderInterface): self.encoders = nn.ModuleList(encoders) self.downsample_output = SimpleDownsample( - max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + max(encoder_dim), + downsample=output_downsampling_factor, + dropout=dropout, + causal=causal, ) def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: @@ -1217,11 +1221,16 @@ class DownsampledZipformer2Encoder(nn.Module): """ def __init__( - self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + self, + encoder: nn.Module, + dim: int, + downsample: int, + dropout: FloatLike, + causal: bool, ): super(DownsampledZipformer2Encoder, self).__init__() self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, downsample, dropout) + self.downsample = SimpleDownsample(dim, downsample, dropout, causal) self.num_layers = encoder.num_layers self.encoder = encoder 
self.upsample = SimpleUpsample(dim, downsample) @@ -1310,9 +1319,12 @@ class SimpleDownsample(torch.nn.Module): Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, channels: int, downsample: int, dropout: FloatLike): + def __init__( + self, channels: int, downsample: int, dropout: FloatLike, causal: bool + ): super(SimpleDownsample, self).__init__() + self.causal = causal self.bias = nn.Parameter(torch.zeros(downsample)) self.name = None # will be set from training code @@ -1333,9 +1345,18 @@ class SimpleDownsample(torch.nn.Module): # Pad to an exact multiple of self.downsample # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds + + if self.causal and torch.jit.is_tracing(): + assert ( + pad == 0 + ), f"pad should be zero for exporting streaming models. Given {pad}" + + # If we are exporting a streaming model, then we skip the if statement + if not self.causal or not torch.jit.is_tracing(): + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + + assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) @@ -1609,7 +1630,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module): k = x[..., query_dim : 2 * query_dim] # p is the position-encoding query p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim) + assert p.shape[-1] == num_heads * pos_head_dim, ( + p.shape[-1], + num_heads, + pos_head_dim, + ) q = self.copy_query(q) # for diagnostics only, does nothing. k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. From 5c04c31292b87dc95fe0a9b498dc753f509ebda1 Mon Sep 17 00:00:00 2001 From: Yu Lianjie Date: Fri, 20 Sep 2024 12:38:52 +0800 Subject: [PATCH 216/216] fix open-commands path (#1714) --- egs/wenetspeech/KWS/prepare.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/wenetspeech/KWS/prepare.sh b/egs/wenetspeech/KWS/prepare.sh index dcc65fab4..e52e1a9d1 100755 --- a/egs/wenetspeech/KWS/prepare.sh +++ b/egs/wenetspeech/KWS/prepare.sh @@ -63,8 +63,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt pushd open-commands - ./script/prepare.sh --stage 1 --stop-stage 1 - ./script/prepare.sh --stage 3 --stop-stage 5 + ./scripts/prepare.sh --stage 1 --stop-stage 1 + ./scripts/prepare.sh --stage 3 --stop-stage 5 popd popd pushd data/fbank
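
The streaming-export fix in PATCH 215 above hinges on guarding the right-padding branch of SimpleDownsample with torch.jit.is_tracing(), so that tracing a causal (streaming) model never emits a data-dependent pad. The toy module below is a hedged sketch of that guard pattern only: it sums frames instead of using icefall's learned weighted sum, and the class name and shapes are illustrative, not icefall's actual code.

    # Sketch of the is_tracing() guard from PATCH 215 (illustrative, not icefall code).
    import torch
    import torch.nn as nn


    class ToyDownsample(nn.Module):
        """Downsample along the time axis by summing `ds` consecutive frames."""

        def __init__(self, ds: int, causal: bool):
            super().__init__()
            self.ds = ds
            self.causal = causal

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            # src: (seq_len, batch, channels)
            seq_len, batch, channels = src.shape
            d_seq_len = (seq_len + self.ds - 1) // self.ds
            pad = d_seq_len * self.ds - seq_len

            if self.causal and torch.jit.is_tracing():
                # When exporting a streaming model, the chunk size is chosen so
                # that no padding is ever needed; make that assumption explicit.
                assert pad == 0, f"pad should be zero when tracing, given {pad}"

            if not self.causal or not torch.jit.is_tracing():
                # Offline / eager path: right-pad by repeating the last frame.
                last = src[-1:].expand(pad, batch, channels)
                src = torch.cat((src, last), dim=0)

            return src.reshape(d_seq_len, self.ds, batch, channels).sum(dim=1)


    if __name__ == "__main__":
        m = ToyDownsample(ds=2, causal=False)
        x = torch.randn(7, 3, 4)   # seq_len=7 is not a multiple of 2
        print(m(x).shape)          # torch.Size([4, 3, 4])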


T#O5Wv!K}XLELk#%lSUUjJ|ofTdUCC^Q>U4G2^18^ z3|EVSc#PNQT&2ZsOYoDmQ@gJ|9jMpcA4s@4~-o1Eu|~Tr687Ga^d|^ z9PAJI6V~TkJv2 zs>sKp=roNki4TEpe`k)!OI2VVvcaEdE(vZIF7ru0Q8Ik7y-U$28vWVv*#G%_LPip2 zzIG$9{@Q|lw1A7rh*joJs_y^!_#x6@0j=f$YF$5&g)Qi+PbuYN&@7Y+uS{#LIS0-A5YMbMz3cy8Av4Hhn78d&lVVyGGie~KcBZido;-#-s;WuT58 zj+m0RK{b;@CKMoARTQtsqT|Gj3J0ds#Fgf)%sWQ$-vR!tL&z@YLojYT4mJ(GVbGM)W47nLEUV+>YUV`*k~KNNMzA%RTGbJt`umt~+)h1eLm&**yx z8$k{`c4#s)|BX%?JQkgCKP%Fk(LV=0!+Yx+Ol9)tdnQ1X(EGLT5mCUTxq%t{BOipD zwWv*xXu7Vh4h?}7J{bl>qSy`+C|PP-DP&@2DXYXReyB6~CV|QgU8Q8F(Q1i=$^wN) zmqnmW4~~qkUzyYU^wu!8oxj%4Pfm{j_o>-xO^IGK= zP0iJVfkuF#ec^8cDd3VQMfDS)s^8zWQS;tmH3Xr4|0nPn03z?J0Q_~cQNswH0?NwL z%iZ(!o+;1&7Hn%3w&g*|BN}g5Q%BMr4Hwnu8`s-ZjKh5w9I-;zvlWwl91SI22}RSH+sx-?`+ z2?{jnltp2rKS|C`1YJaR!MnH78W4h+dr>rcTL%v=yG(|o;Pk*Mx4+Om2 z|GE-I1E=-;lHFncR{XJYAa`uz*#7GwZSJhr(&b2oDu6pd|i57H+qJHVv-6 zAgp00VAjL?5JZHx63;_MQ3+`Iy-})rEYAcg1LvwjWj{TWZ2BwiRYB!Bf zF#fr2JHchQ8}ym$)Le*{>htk@iLFwPQ=|fG&|D~PvRdvm9!0wSP4jKFSc$QwxHCkHAuJv-IXaF&? z(gmdN2vdUMMJ~F6U9i|u_IlFd?!uvydW|M&P=gwbv)!nFkrT=6DYG$ie^YF=R?5=* zCh6eJsR?~&9(d}w<}V2x?+hJ&z+96)A5eb%gV0PNdOL=d+ruJ(TuQ}~=8*B{npb*@ zOB=5?6I+xKaBAjomz8si<=)j+AzH*UBdPnKq1~s>yqX_!{`N94SwbENa}IDZ)MtC( z;GhWAtCkLWe<~5Q*XYDuA2{>3m4_7e?V-gqbIWqZ*o4vra`^EL9#gkil*D{Ym5D7sL=FQ- zPd+wlRtRB93hdM^g&i(nlzU=A%1n}BWVAhWHPY0zy1QNJx7hFJO)sXc6vi&Z20{9{ zcgia|l)FMZ|K1P1jD%F_8=f=d4kheMD<4(`YvfMEr!nkZFyP5VbXG;5qaR2Or$vSn z=PrbrcEM^@Qg^eZq%vHV;EPPIn}p#^9~wm=4WtoU#)!u)BNbbwU{VC{Ne>^*hD5HS zd6@HrV}q>z;x}4IMO1Df@`dlNzUWcj3ZHkXt!aNWa4DJneq~qjW_MLaNer^qcME0| z+dXv)u3;6ufFvJm`a=w1czy?^j7U&i)yYsf5rwHOyfUv#N~ha=UAdgxY4@9ssPdtz zy`5*tdNhh!-`h`T^&P$V8bk{ZuuRcysc`UiIFB=w?(6=(E$o?Kv~c276X-*=aEC>Z_s z+|#JHYBXwbUWypJw-98QkD}~d`)~=Q>413#i&^tJ{4vBM8D+P^wep}LTNi@xEs$~& z#nb8;5qiOEl|>!;#ieaX#P~?rjf}E<-hdF8N=>MO@;PGkU@m?`R9#lu%)Ide11F0E zOO`qj7wtz}DY5YP%JM>iAl3=Xg5`RABjFn5tisHshz%kw1Qn(^3DEq~e6GHw(`}vY zJo&L19mO8ZG7jRFRu1qL98dO#fn$p#e<(GMf9ITq*VRo2_Z7uT4xWrBOzCZWbm8Cz zOP_%Y(5`+D;+ zA96ie&zx({F~)uWZnO+Y;nXxG#$qCMNf6}<`_vF@|5{2GC#&*UNL4~7v{4mVT}r*9 zL%_;KT$TY?&zB=&9dzNFk+>j&K?wa&+^JsM`Q1mS>n%zJY%o( zM5eME4OH~S<-C}o*#)!P#kf_8q1+Lc5^?NR*(G%;h5H%AAQf1e^^v$hPJjGi?4{YJ zs`e@LpFvAtWMX-KMKsi=?v3mR7Yfpi6OOeF=W7PtAbC;Ri%KK!hMb ztDoV8Io^_Yprw>ErY*AKuj_$Wg~27_|@TVp}eyS{W)rtp_VT0_vnpFZ4ce@?6l7etkX0y1@i^a3z^u2r>0_BbT=Um^}pI=4X=x?!@u=2^O^169E`2( zcmVsK_gb;J&*lFU+kc0=DVjg%VuJMAoG{?*Yky0YrY88f5)(t@$$db$&f z;nqm$Gv~m@FWro(?DCQURgsuljdDHp2L?Thg@H#AS4I;l?a29|VJ?LE63Co7k;jtA zbr%J;F}t$V=w_~e%p!+v^(-{mh)xBLJHi{Cd*q@WT{x{+^~bIY^v6;d0+BLWx$9qu zJCKyoI#gzezzKyN$3n5B%Hr^&i3vFZ2?eKv?qe;VNih=x!v6})xPyF5BMu_mQ>*Tu zF6hQ3kV&-~#Z=AoIq!(_;Lu$dZ=;WbKqt$gOj&Pb(hS;sG|3uTQ~Zx5mf@#C2~M*( z_jT!km*rh+CMCIjjtbOA(KBuKc{F;VIiB&UHJ6{73o4;-8AAbuBk*H}7I+j2VAx%| ztq`Yd3@93xRIsBFymnM>WrksU_383A|V?hNAiqh(3x{ z)^ll%N%Vf-P3TKiVpgss6ujWQ3D-_Oh3i=kV&oO1Hfnc|h3~a@h>gcE{7{fE74Zqb zU=2L+`f+%c=hc1LinZei+?g(b&g+e}@gZ7TKu|x&^z1N$IE)Fe{2QU{Tt1l8-m?@{ zuILn}8;6CZ8;4uoZ`^Di)=KAGHbaN&R;Go_hQ|i3!c^^iu?pag=5%>&qzCM8cmM6T z|16(>8&Ed`RGgSYJAm%NYY$aL18d;!mQ`udJZwf0k=)L`! z#0&<&82y3pQz{_O@(a+m?!cmyYhwwbm<=K%_=~z_irvd1#0Mjn zh)EEs9an>0zffD+SHo1X7{31$jvy44FG}197D$dBSes`H{G^^@U}mW$*H>#-8yLg4}U*cWYi!cz-1poiv-NFm&}uc(<=9`8H+s`HfXI+Nhn1*GWVteij! 
z`B4OYbY+MwUJ(5f;tWb8ye15ffJsbizVzNq3adtErA0vnRj07|@e8G~=}f(GBxWQ- zxMjG@MtUz%S2`8+ihka}-Yp_PbUT<8}p_5(SONoi^JCTxEzW6V5b8~n1 z7a1PPzfT+g@0j|xKz=gGmdFipC-3T>(1&WuFvsio!#pU^g!uI1WQoPR5qSV^fhlBvgHJO@VN_27b zuj*71xz;#In`hMR18^9Xs-3trwEz+5^Q>BQR#jR41m)kGbDaoFx8PwUft+Jhcvr2< zn`a>oVBI1|OZ?ZV{h$3J3bp(*{!}#m`T3b3zE_3b`=+)?`_!{zH4HtlGIy?3GOnK8 zABh-fl=%~%0|Mp1052cFuOj2(ajQfPTWl5%HDK@_(&ZH>Yz zcB>%0a5x?kjY2*C6bvq?IoY`7{!T(!!*2sbTB?wa(ano#yBNlCbcvDjHLOI7p2#1r zH491GRvvmU_%W&WvUbNeXuelTUOqIUB*y|=#nVRA1fGmWn8=-9rj8d3lIaSFq~a*! ztD|J=!!4$P_e;^P*3Iwn8!PxZ85ez77&PC7PM9YmN3^U5Qk$`qYNKD`lJMB-L^fBv znLjQn*f!aLK0P4%+aF2zbXgE8t!cK+x0Bi^%<#*|@k^ps0f5y?F@hCKP|=18Bw3fU z|MNrwR+AwC@Iyr84@mX7oGq(dY8u;npWTOUd2L1$>A9{3QUat56CnwIE#o^Ocj6KK zynY&2sya9*+iEAZCG3AV80|sc$oOiezU06+E%TOFxq461lde-!NY>{h4T84ggdG=a z8HR&WRB?cxUPAQA-!TkdqYXw~G%VVm9#7*JM5J{LroW|^4D^c%YUpnj@YAIw;6E_w zlNq3*?7~{{e0jYqnEri~V4bqoUEN=5NCG6JiaVJ7g;6+P&$n_?Qc_}L8JtcP-GTV8 z@vVU1<%)HhTi#CqIcV=`ZD|eVZI1Pupe-A;2i!vV zN2p!GCr2kZ6RR!O{v4!noCtl4j7)qfQ7zz)qUH@ebM_R<)o*;cC#|}Duc;b4(|)_N zIO|QOjLydp_;_-wxQ7{tR$YWW(l!ypYO6zr)oxPOTwMvOXdrq=r%oFU62$B*|Fc^r z(X0GNe9mtgT(;ge&{noc!XMd<@HlAklRdT8{267C7{9E&r>0b*c#+PBee~xRC!ti$ zZ}$Sq!1#F&C-weJSRE>TO^hpt_SSIev$*cZt#qc|rVo*(Leo7?WklGK$q|lk_%bIb zp(Nxomg~*ez@?AXd|tQ3SLq#nzj-09XHloXSkTEOL_|d8qLuTFrYg5O<4H3{?F&b}5IQYM)6Xq#p#eYHO=Mr$ zsuINB4@WfY_RN$c4l+ifeD@e-msP66L*wsK%eU_(Rw?ut`}-b&@36n8JBRXlZyGE0 zsJL1oEufAYn3DLlnlLig(XgS1h4@$r9S{wN4!33a*%O$nH=d&u% zPQXBGbR)sM<4%J|5iCy)SjIUWvEoZh23_+8iB96*9M5M9yaS-}hN3c_pN}Na>GItv zfz;C1hl?DakLPB)JwPqtc%~xWYMXyCZkJ{aWvveNq-FqkryBoqtOplG!ilVs4H>om z_h-j@4KMq37Jv@L$EL;A0B$s*9Z?c19n^!0so|pAw&EO{>>IL@NDV9OO$0Y$l90lm zELp;&q=3pHYisos-B;ic_4Vl&t22<2cX;rf!_#%HPzo4dSS~eI>HvFnzYSGXcVSc*C$cQdi>j@gVx>jv= z2no3d<$ZI61^AE6P?0UM`TT_LSTJ%ca_dbwx7Z651Vg_gvdAcgvuu!<8HKIbHgZH`<7Y8VKuYKclyc!K>azC|x z=eMb@Im@?5vVpUagV)Mupq#B)w^evP-*6+fIh~>7;UV%PI0FsV_4`$}qt#~Z|9%$x zX>b7;0>5*D=f7>2^5j@*4W}1^!R-BIn|KwMvAOz=BWu5VFx&!ckrcG_r%NeXkpyt` zNlg#Y_DM;y*ioB^i#aMDov1tt$LeN@e0E}60svaOk)(NTuU?@v3teH}Yfc&12I-9MJ$E_T7+1RO9 zkDze0B-XN2S2-()=v3$ zfBw$_ng-e!xnIKnzE?N(bL^&D7SiIP#fi`?t9{#i5g9W%@3>mRiK#28X&fWHVG8sD{2@(%JN)9aV<-({9wh_ zt|*M4aLjZ{BydKk#p=XnCntMe>wEKqGNJtWE8tk5!ams#JR=TwqEhq3TBy!s$7_9a z_4G=ix6$$dkt0TQtfW7c+}M&lB=EHn^E<>LY?V4&Z7TGwqbbx&wMu2(_6Oz*l?5bL zc}a9I(66OSi~f~P594WooL5-r*FUJrMyqoLHd?`X1csLCB^6^Q@{WS~tc2Q!O*ue-f zF|4$-(Kg;KK#!j6P(^1Z!Btjgg8gVFU3nyC>RArkllY?Awhd%dXtz3s4Q(HQA=4;U zzviIO^&3L#69w z!Rv9O$?evdRPK*J5MbY0Dm{39%_b$m!ovEIPB->Hi`1)xDG(_T@fg@`BLS?UrXC*si|>EDQQGl1Oyt5hxL`Z za??r-9BRLUf+B?~M@o(s&w+5*>+2qMBOP_&!LU8@czUhyN_4Fv=d4;kmCb1^M8Bb| zpIK>XN^Q;+SH0j43y>X5#&5=k`GtkbmAJq$g{6tQf0RJFYl6X%oq03ScY&7qW?wQ4W^}9VH|3KC0xEr@lU* zM?^SXDtck1;!I}<4GFxB7dY%ubu$Tzj8w!3B0<9G`leW_c=GmqdwqRPm?U7nRPx=D z!$fd}F!KEToEa8241+FE3ARkTE_jr?g=e7C6@yN(92FM}ng7A?7ZuJaipm#QLf^R} z?r*=PH#ljUrY0x#v>7*#i;L46%#$d!I7vxqsi_0ZnZo=uk-7U|Do{q(WlM&KM++5y zP*R3rIUSy>k>T8iz#F8&e#6g^?eyfR?8GGWuEao&UoIaXr`kc*tR3CN8fF3CtHeb8 zh1Z-8;=pP#ue+)n9!Ex!349B{`yTxp$`YlS*|rNIAt6u_7VYgWf5o6*)|*tgNltW( z#v?Ce1FPn4uWp5Z`q4wtat)cl0!p~Du{AiUv2XksZk~~VM|rx z28LhC0j&PP{7FIyWr(z3QSEiIes5|FCp8Yc`Gjn{rM7aWcIGr8EKTcb`vmXr}wj?t> z`HQLOJCjff2S>{_?6oH4rDtSvm>-dgGMLKe+jFg-pVQfV!Jv|q1R64Wsv>|B2LAcT z@aWk+y{Yi0X7+3j?g&@qYH2xZHb;6gEp9ZIR!?05=I}6DWC4<*!W+|FGleo`D<_%? 
zic|e*DkuvZy~fIL4Wr5jHHzo92V#Rq!Q~YntKz zb%?#^;jgZC@g-$=-|zavzTF=XLdaEWav+UhXuJBkj{se*E&nddx8gMWtd{?NhqHz{ zXX^IYJD$nY@wjMcZ*PYS3u_JiNLMMx;Ep7ynjZ+Iha)O!Yn;^=Y3=`g8ea(Wv>O2u zILy8*Lcv4rcC^EqTO3YgCahSwSn(V?f7j29Aya6ix%%|pJ3Ngor&I}+k}FRm;17D2 z#^do2UmyW!I)<9+x(YaUvSG_XU+(5)OvY1tWNv;Z+@CIZ9RCSX?E*ZhfCJZLB7Fg9 zHz`polaXpQ<04IWYPGeAL_Qsk(KlVL!(KsPG6H-$@@!WohH4tV?p{c^O`pe({0K1+ z!EL5I;mCFXP0w7%bUolxq0tBc>5cT|QDT3ENM{8jejnvr&uVkl3qimQLBRi5CXOkU z-2=3g&o^US&X-W7Am(Pv#WR8l`|G0*(5L55 zTr5}TaavXE=hiT>USV>b{(S^6g?{~kRQ=)daNY&r)d3S?-eV6mJ}!s3BHZQAB0ymH ze84*wCBHjg_Vr>pJQgOv*qQRY0J8dE5f!($?CElqdLX0vGtQ15u$ANU|CP9xHo zyVCDTN9ZY>u+9rDmMQ^-hT)!Y)XmLpll98YhsWi`azliLT%+X5pcM8uD#3p&E43y) zL3Q)foR2&5FJ$bE-e-jqV=9%YfPMy8KwYt}J7+eXhj)>Hk~H3J9~uGE_7nVT_2WDc zYM{0aiJxit;s4CBl1jT#r&2|p?HLbO@f9L{cD_QRDZbB2qX`TdDjC6Yw*KV(QRaN3 z?GdQlK3`J0GTH6Gs_&1cxS92OAxhaGf`wf+*IQ@~ph8HsXu`>lvRWMp$ssj4cDSVDV=z}3 z)$I?Y$KW?wuRT2F#UxkB`i)u~l5}J75l%mQJxbx)zg~1FrlhXvc)bIhdN6Oh4pAD5 z1ah7`*(Exy#VU=*=b{{c2C!1=wMK-*QoR5u7kDhN;HQc670o;*qcK|yZtjDfm-{m> zAUS!rK*D07*zS0m5M7hm!~NlWrNrUr9_NQ{NQ*ipd#aAw-yC+va&ffT*BKFfTJ?IE zYMk{f_XQwfgjpq}N)-^rRs}fP3tlX@ob3QS#^XPf5DzOwIo@1D4qH7(Fp1Fkt{`3T z&e!pz_|h4J? zY{md6J`uexV<|wl_nkignmW~HC7!fQrRwL?J810B?AdRZZ8rVIa^2;c#{h`ci{0zx z!^}axk z4_;r-?C!P)_PeX?&HkZOM(%Jj@}*SPsJ#<|9eho94)ssZCtErQp$Hw%>j6M~**68% z?&Sc2?K1HM`|Z^AF@hxRvi<$X6WK5qmg#);jzJ91VE(=CN{N|{H;2VyiVoNc^LweV zh>ZJo&sxZOG^DjEp0&<+DkvVpt`2GW7hNO$T@F?VCbAfQFIgz7q&T?Wt@)JUjh>*S zYYkR)2(mYIGAVTG{T0`<67-B3)+Gwo>m9@g@~weDH!9N)=r3K}F2xcGqN0y{OrJOO zOhp3M_S2a>=+CEO?G8?`x8i~9__q&-gg&Q@|J6;fu7USO_JyR;YEAUvmrI5u+ok3^ z^T?tX+c~4+qy-C?)e82VnO8cleNg9UwbRlvbxUeErJ79YAeOmoKIdTN()g%WBN1*_ z-T!`+u-EuqJuN3kgS^$w+FQEd_q4g*VI1Nva_=cM8{6gU`>RTY_QoGU=xR)@V0B(Z zo9DN>sp2`kEc>c#9_$~M&CVOG09tn$aK?E2Z3+2OXtlJ-s?($a?gp8n3&t(IM-LR+*2K+EhrZe2!M41sEMZ6zzU=-d;0;^^eaAMV&)`s&n0t5pyQTCLXo<(a&B!rDH*GG$ts zOzsw!%gxh;(mI8?K?IzSwRG8moT7PC#~qWLg;m!Wmza_2yeo2>$O> zTQa{1PQO;bJNBkEGwk&X%tEm&J3U^Y!2A=ykHTYP0>tcOTEN=A zI$z+Ba(ue*5C3O*Q_8fH?a}IV;e06Bo7v(}XK&BxcrGfplN)r-+!Ct4%oT!gwvtCq z&qS?U34sCzwQvN*5)twFp5x#puHl+L@QfMQf!{O$e_u3Ca9V5|h58dwu&Hf7MwJoc;3p6!j@Oe7{(3j>J zpz&5NLq|Z!eY{lp(5m$0_$|+1G|nXXTQK$)-2{MA>j3bT({$c6uo0ol?RI;!2V*7L z=SS}heSr38Liii*d%fdIEpR^V&lVWc9ccCa)!zyLjF@&yAg7?a%&o*CF?hUo{-=ti zBxwuhPRq?^bAaeal1N7A`mrT@`%Sn6{4K|!-^^yD4sdI*V%7I z{_wRq|IL?9h#xdlD3V8aSlVz2SX3%kjRRLrot4U9VLUZx7WNe7qRdrH^T^C7aJxm1 z!e-@m-QG^o@ov%E69%(b>*R4+EKg&u0ism!#Z6Wslit|O=1qYF&h~LYRsXqPqJVI= zRAil!HLC4(zmPyKjfl%`x~$a)($lihj2s%PE^{E~rQkMd1(!a!G5Q9Izn_784K-12 zIC!I~T4#Jwke?68*yMxX7P&5cEgnx6Jl2=(YVJ?%&_rQe(CKG>=h$>SlL`6f3CA3= zlifUQ)Ecb^cW;%o<6E;JregoNo18ToyXxLC|6#RGLQ;Q2j_(8iGUDKi)96&D3=KH* zt+I%Ap%6~Lt!3DBid~S$!Q-;WLc=_YzMVWxW^rH?8OX{Cv1;9mmBp9D$Q8x01b*MI zh=)l)D)lYFHvy9!9?uId15qEH5TL?lF}0;m6l-0Gdl7Om1KLr{;BOUuI(~1=!zxIn z!tvx)a_E00HV^*SixKE;tO{47ap-7hFQcw^H`$g`GPxIoh08aU9^;Z8 z>gohmFeyjFz!6@)E-#lhj_;K{OSmM1O{Ye*8O^F(8=8 zf4%xOT5w|}Q3P)%f9P(0zFcxjkGN=Wk6)wn{V*f+>T2u#<1w=e%(Hv!0dSbQedG$Y zhf{NTz23jBD9GrMx4Jh>WHia1JY40)IRBj`W#u5HCN`LCcPpPXCS>^IUk>`t{jzKT z^VmjPAUuKn2{;>pWJtsTm$9zb=z%fb*Oe)2oN4^PVz;s6x3lk0HlLmEhLo&4JdOO| z%D?>of1l7bMYp)>ka)ce&%5KR^TqAWt?1a;=hF~? 
zhXGnRuNy97puM&bxDJOCL0tM-5gtHfH;|t8-t#d&0_Jb^Of9vJoyXA&$UhiXd{Jro7#<2pzRPGgLw_y zSD9_iQOXo6C9%?>S%JhGo1I?J9EQ-JI)D$=Fuve#?EvWVLs{`=Pb{s`b-l`|t?>RnFfN05nwO^i&u zx@>e*Ri6N!zrMZ!UlXND{@?+|)SLgl?8Dd&_oYl$tM&WK534o#GZ8xqC*Xf&J)7?2 zNg4ob_7ExEqJ_l*^g!vT94DnVPQ`Q3D=Wq?$6lF=R;p~ zdZBpsgU#pjVKMx_&8E<(^?n6OOa$w~m6Iu-hdM&RkQ>XBr6&4?X-iD~ z{pp<|xgzj6@U9!?uLn)%h(r0UEx+^OL`DxB zV1f4;2l$6?;DoGt;b_+<14m1DX{c=TrNswKt#_6C{;N((%TRn zy3y0V9as5tYAxH``XGPnMYsM3|Sjb}{VLYDph`MU%gp62+L{H0@jIfb?7MD76 z46WGYp7uzeU~SzD-VyDs6aOOOX|GOP5d8*URgHFFLlQ9+!9I zK<^elcWPXk-TrZ0VwJqC-ZwG6a0ipg^Q|WOYEO=xWK)+MIlUIL}-=LZJGUYh4c%yBd}S;c4oNh%!7PSx(b^ z7eZj-DZ`I%cga?Ca+rev_n-0z!mE1{=u1E$;HK)t%GezP@pCC&P4HKtfdL#$+|GZ` zHtH}SlSFGf2y8EwulK}M(tv7Ai9&~S`?Jmk?@CT<=%Xl%3<5<9Eb`aopfAn<+GDhZ z=3X=DgNGS?*lmo?ri7arU5ylOW52qv$%TsD-FeFy^cng#ug?G$MPpCe8$K+htH|?l(NEt<`rPv5X%k>yzlF zlzInFpyGFu=K*_`T17>vhcb3bRHs-mDSK8NHcQnme(VfFrc(rf6RXydt}ZwF$Q~bB zCUY8<4^ZqRJ`UmNJ586b^vY)sLXa3r7J_v7Im^as@HC!*05}dKPhB4N+f=$mV!;7V zi?w8aLP}yHG4a_1fNR45n}#wUXzxYp3r|Z+dm(E^NdUA+pwWhCgmRt-;A@vJrBl1> ziy{**h~-LRH1mQbf25Pl|FBx->Za4dDlVN1zD=z!h zo5PgR08d#4@~y{}RU+#QFzeCj`oZf1Cuj__585J+s*XaJtyYLa zXsR+zv01NQ0W&duf9l$-Kz(ix*VWKD@9PMi>X!^uLS90^c6v0HR(HBkC8F&LIKc9C zLJ+?X(jn2N4`s@n7f1i>Y@%R@?g^vlB;o9&gZH|n2=$EZ1Jt|&{&`ezUwJaI(^lq*_B zvs0X^PIOI1_?a;oxS}jOC_MY9kDFzK%$Z*TnY7!*xF%Il-tBg1G4SN#B6w-+BFeci z^QsWTqKlZy5U4@~;e4om6ku`?xiG(wpO(Tw#2`9#naPraemPzj1-98#*(cb3pH?22f-*yN+T0cjcLbYj9b z$zFYY+IGyDKaj3QQY=L+cy z;Uj+N_W>LY8&x}1i04pu)a-(SxS8M7i;?8YzgGz2Q3b_LfHs+DELlwLtROTZYAp9^ z8(Z?CJX+y3u-{t+BN_Hbmgtt@KiKgq-+}B)%X%;n#5TWzy%{K~leLzdtW_d{d zd_44jtF8HtI5bH+GP>H-_1h>f8NLEISg0{zL(v3j!U=DP&pF9kMA^>?hmVoSJVd3? z*KXhM!@5cc2`{V0|HA@6k;#WMXoKL4S$1P>X<0S`pU@%+eE3TmA;O;tBR?dv7(6E+ z3T96=mLT)<{vf69%8v$9$=)y&iMVPM^t+;?d^EZGGO67l{sfQ5BW&GUOKxZk*@yxY zQ~m;p5KM@qLRoT!=(hw+qGnrHvE4t%@o;6H$FW5MP|qHgDD`D&+Raq979_B0u*(_EpjL#eQ6GyE9%7N7Ses-V}@hNxF5G zEAoeHVS-QzrQcMurG^s2JS@qYt%eZM;ouVT64Lr`#|27o? z%x7rJQUL(q0!210WY3pS+t$Z(F)hNd$(Gy0xFUn=f8+aM_})KHDJbg7CTuoweso*f zZpb5&d13_mg6y8DW8zWFb?I=l+A|UfAb(fdh8#-k2#I;|kkFvU!g%1rgRU<7=i&$$ z&}N7nWX$jTDDk5|k9gfQwn+yOVJ$HLzFS&W`_VbA=?YWC3Z1aqj_gsffuRr7rrtj5 zsAGjr)8>xgd6;{uzJ{~|QOoCoe^v8u?$rMIz?<7{0E6?0fhG@06M_K>hYh=kjz-om zdsT;Xr0Z5g#FUxlvbWBe`b2nfD>7fTq1K=C+MD)uvyLRW}*)@z3m}9{+-?|?GUR*65#~4(SZ21 z+h08S_zSzv>lh*)Nm8G4?7<;byo-5YD8PO)=VPBdDOmh#T$Sb4q`~VGkS)*@N^Uj69{Y3_ zo8{0K#inVy6SPYxmM}XTfEz4$dNr`D?cq>~>vW)Ik|P8beSwYtv4@XLEI3~;W^j6r zh>9ADK9%{%#0T0-?^@s8#BZq(2GS zLY@!~TYUIJv)s4Q7Z+)hLr(_+0S_}=Ws`f(Z6j*$tJLJjk ztJV2zl9K&3rvp2U#B)ci;~qD*8kQTZ$pitFC_^TznRzk@f%&J)87P zJUo*}fVR_uk}F{qnEWimO1N&#lamgAQsE5`Y}RddhhA#ZHA3QHYb5)- z<0wg0%f}1v+aVjX>cHntNZ{`0$K@x*9HL@H(j>$sT!TI{Sk(XDt^rQGi(kG(tj?X_ z9suIhpzQDF0rtK|<`D*m(deT*Dq*6tU9YXIimFOy`Ziec6B3!$fabb<-iuVL6L+Ua z2&l4Feo_QPzx}v5Rjrv^s%Ys~Dq(Hecm>6g1|XP@{+4EZpCK0X{YB2?7u^k@OTlNe zrlE|tk1Y*D8~D>JeS%J_ad>zbFPNW*g8NZyvzrRn_+UU<8y-XmqSa;B2PEJM^P}F7 zl0v7@Yjm3SsAT4jT~n6VkZK9c^0ya|TqMNsdRyL^ej!ITu1;UL!R2DA;iT?!kV#QD zHUc8+<=Di-uo!#LPbR+>Nq*d>@+xc|m-b5>>u;S=3L02Z3Z;2mj^#KDO10HgHI*F6%gGZ#MZvRs4 z+XuG%C4vN(dN}-(rekVi%-~EAC8gJ3P8e~xkdKL;txACE`g1&)mZL93pn|^8M@}7E zV|%>!R0s|oL`c0ZN(DukZ-H3K{uxUF^l%WM8>M>Fcf8F43Z0Tf!y_r)ioTlk4OJ6q zX=e*6q{qu6TxYN`eAx1?k9+!GVOID+CAV242pzpb7`EQw!6qF~xfRMV)9_F?*6e5+ z|6EHV_6w!*=lhG6-O7QXiWpWxM^eE%qYpK`Qn`>E;m~x#sS#1w?RX{c=QHEO`+jJ? 
z&4%+Dj!pYWBA3(g^x9>!*Zzu$&DfNQhd2usJ+oh6mbyw6#mV_xo64N=1Uqza7DOr( zptYGY+sd9`;aF&u0E(wpXHY`p&7+>N5FO&fiL5r=kGI*{<&q%_ zX1#Qn>o?tT$Wb9RB`G87+d%5!v3#AJ&RgRtz_86*{hca~@B??zantRnRtI|IZ@|ES zrl_{p%E|@bq2V83+!C(iVC=*3BIBcv@G}DwCgh0fI>sU9Y&q;$j(Frfh;>gjA10Cb zoC7+no$4T&=2NHL6c+9Q`H0pZJXH0d!WW$%_BQ%E*TF^@sD3oN!}I<^2t(z~YW}9$ zW+=_QVTStx!SInh*+_KpROeOi#^-ZR`PtD$bV5ISR8|LTf>33lr{Z_D2@v7a|9fft zcM|=#TO_U^vaVBh`I52K_p@(RormclNBeD+-WtrLdg)t%1=M_-^Q=KhHtKe96WU5* zq9g}09O|CBQP@95EYln8=fDi29^91aP~7HZ5|S(DJ<-{wNrU*2x0+RkKJJQmjoJqN zy*&QqSHyEQd7o*}n9^d*^@k^|=+)$l&X)*=M+LBZUytd0Jc;n`UhUQgBXC?N3u;-n zUkhuI(bSg8|4~2S?Qj*MyZi^<#G-tKt<>%?etWXBT5IY$&W#ZDvCSXphy4aj4*eF& zlSSg->MZ(zcou24eGlVY| zd%|oina!QXTosl&>|Sb+xIboAlr;&vN@5)x#cp*JeR0`-zxdIcyZ8E76wjDvrZ8#b z{hH;nOk?xW{%q6!3}w*St|^#-bFyNC5MG9}fSOu2o&i!o5<6FG=vQ7#?a z1AHGX+nA@Yon2vIio>xHoWSS(diL3N9zHe_V$TkH#@CWwgvi6%sJ(VZI4MzIki&gL zS(N!UGlU;UM6%=28{2ap|NO%HDK>bT8AnWMw}A|EVALOgAz^G6cD8u2)W-J8Y@f8U z6)R|L_`RlJ+BbIy$}Hbh$7j!}5R#D)JF{XuJO3!l=O?7m`9d>5Y@i|qfI&5@DD zS+_S2z|>I8jSSXpU$x%t>rN7m%&9=eSBTWvd}0pIu)PdmPQ_ueSkAWTnrD);*%yqs zaDT!G0m_(kwz$JqWR@f2byHHb%$aOS($=W!>v>{*T~mPvp7uinlnt#U z|8OYD!{>n+k;sHi%v_x^mWG!0z!Rr2f|b-FDS9@p0ecgM`@a{5|FzGFu|a1MJbAkg z(=C9(9-GJE`04VGuE~)meg3_%>=h$#7p2aTUlPyL$|(yty1fvLb~HW^J)Xk9_0;#; zA3C7-5!O2Cl}A&OHcO?Hi_mZm706cpSXnn^f&$l4oOWTHv*e?~jqu7Fp|(_pjwrFg z2g`i+_@uc5p2V9C``d|pDgdDpMVUaJpuMx4kYUqe^c0GGcs+3dKGSa4E)`LAE#C^j z1YUwLwE9y%f#s5xPr?qU3JALxPH)D<3X;k+CUKsT@4s~wuaBQB*O*jkS4HAH?n=>3& z={hSW!+Srk^Le8(35LgQ8tV+}$BFDeXfoOJRE7T$+kf!x^90TXYiw*B$Jz3*YT*Pu zZoko5Q#|-q1%sHjvS!^?;oZ@CzB?4qS^?;4GAi56hsDT|%`*CC!L#AlMp!Zs{mIh0{SxZrPtzkhrWn8jfoih7B?sH%}5vg{HKo>xSRsBOz_uIV@~B*R{`DIwV>Vwex-uC(FTs!((W{ zUy!r>()Xi+tRDAs)nfeh()+TF9egs-W%K#`c1rCiW21)n6`p}nl|pLLai%wDiu-na z*=|pr#fINJiCAiwz`jX&s<~*PO1AMn?pP=Bz6PZFk<5IwjM7P|Ivhq(jW?vF06;kEqCFGGr7OTYO z3fZi&t zpmsExLWfV+hh}0I&+|SUJ@*tGL7tm%j&_H2^Nz42u}p3SK9$tRE`Niq4IufHYYmmAU}7{ zra9#&Dm9mJ;{e&hojk?+u$aBg-VnpG?iOdY>TTz(`hgd$t;ftf#16 z4W3Alv$?q#BJB%?{-+d$*-iO2BdkH8TJNmI&nAnlkHHM<-K-0l7;XXKo5!wCo)*i^ zzZ5EL1LiSwKG0P>;K=s>8bu{SzHI~lZrV3>BRr4mee{Fv14Y5#%!H8t>;3SroAj|k z)P0TdZ)vJ|4LHEP1Z zeJEXPxBK?W09@X8J>gA`=nRXUNx%opx-7M!Ym6*@KAlLVH%^U2XL;fB!iLD9qMlD} zp>WL=HaDj_Qmt?z?h?3kH{aaea<~gcFDA{TO=aGWg{lgQ|AeGPeAH#1r-u6+@~95Z66~`L=-KTfW)_U+i1stbr4i zfr!!6W(V#-pS-5LP=1ss8Ix+H^XdF`^JJFVMO%_Ve3Rv$pK+dFR?zh&nBY{E_h$Vh zCNhkw`8@E0$~_*L-q5iCm`U7LwyG zv!%mhDEh?=0fykHTed|k-O{G#LkfotNeqeo35FWCtaoN=TpQ82hf}GwKT11Zd7Lr) zCKOMwugS~f(Q0HXeujpatIf;O>u)v>eBeq1!y45t$cfpku>FaVSk*YE?oyS;)G&4- ztD$=WRKd30V`Wy0o7-|Oh8~lqCZZ;(<1gQYqO`Nj<|q(?a;`x)GVSwH2QRoU%OQzp zDG6D`Rg}&c??bm!agBwhlV74)82M+@FlhRE1j8^@lGFMm4|jJ_R6N z*lyM*EPWl^HxU0A{sx?R4sx@zt{S@VBzXo{eSJ>si%^oiuFdDN#%0OvEX|{uE>M$b zi8TRXLWTI?gKW7ZCM-4P88d!g$r~As2B;C!M}Akm+^AxHk|2Je!vYX6LR>0skfXa8sc|VxE59pOge8PYt;isF z|B|Wo&2Garu!HJ@(esr$H#pg7q)HNx(5vcx)4>NqJeE)(NuAXKJ(Xoa5ZPe6yA%@w zFPx+v5%&?Ko&?a0G6IUF1vEo{aI!v?pAcPx;~8vI`+UT(g`^{kDE)KshGf8nT%%06 zf@v;JGYhz1V-}dXhoS31$dvQytu?Hj$*^aT+RNG3x)i-0Wz8y#~%8I_-8$ zBn4HFG^fIS@*Ge9d`e;Zcy_bJ?!LFDy%GMWz;+O};C1cC{ z&`8Ki)5Alpzx!hIuQTh3tJ7?hS>k7RWg7dei>uig(~Mt|tF zD`Z%IzLp-H^94hyQqgz*un!vy1;Ym}sLN%)3fw7WrR~AnILpzr&!HH$4X<~NdGCmX`8jGRV3Vy)wv+5E z!uL7|X=&nq%M8;OjZ98*yy1ssKzc8p!1^oG{Zzq{>@!(xS2(k&23)o^m70r=v5*DK z>*a5;gNeK^voUHTs&)}*wj1Ds93dpJa(Y%StsZC`{xiw& z@Hk^9P&pa;aHL;KEpL!7@>p4$HlO=Z4tPQr#O*qc4@=25-NY?|KR0nE*7HRjxhTm5 zIuP)<)F`^%w`C+yukI^+{1EW4pZ2;K@RD{w1nO)QfD5_9r%lv0a6zECOr_4kZAsjG znD}zud1thrBxL^U4K~dP7MxO~&oaEkL;cmt2mf?Qv3(>7%7`K%CiG+#1=gP!qbKtt zYM^Wj0>b^IGqWgrM@9~}Bh?0ZE5XDuT>pwdOn#89pub&i0tvm3;pgW98c5vNXUzU3 z7SN6@C-6eY+ 
z=GGDg4T?Q)V|>C-|N4BQ{p-D`Z;L{r-=LrvHltiPfW>_V6f{<>40-NGwe_-nUUe%W zITB;SR5;(z8fb>&^!~}K)g35IEMjgIJjatu)Fxs>8JBN(EYy6RRAz;p)FND(q00 z^G)U_vkn4zR%!I4xZ=v5rDi zR)RX;Xs*NcEA(3EqEXk>{#3$=BvI<0QPP|Z>Vd9Av4GSV*Oz>aOBLY$ zJSRwks+w6MpgyWjXeCb1jDdyU4TC4fL~AS#2oC!hmV7(t)CNADiuyfea#Gk%cJJis z4#=7*4fMiOFL5Jr`qD^mL~1&KybY!%Dams)8u;@fRErgweovV)7A&Q@bYBj&2kyk> zt@d=*XpLNv(A%RkwTxNZthJGNmXT05*y`KEKyGqU*0gxyq|qMh@}q4Fti-vOjN$-F71rOu zILhoyL<}r%<5BqGaq*RAJmfQxC8!}!;YIlow&fLEmGfoGZ)eB-E-GMVlc`5Ti}d9f zigh%MPuUp?q(5dU#Dl(@#AxF6Qe4Wp7wIa~E6*s#zCA6WelYvYA0$2)&u0%;oH*B} zu6v+0pYV(@rG!Ko;eNNpLhXVkob?b9h&gd%nMi8i{edJqWo zO90>7C2>0*TN89bzg{WDR%=OH>N{KlX@JA7;v|{J3o6j4DPaHw%`*!3zq6>Dn1Rt8 z60s8}jNPYB)bcA(s9=THQ*t*VsH#vEWdpJ6y5*BlNY7*wcMwA4gB(D)-Y4Z~2FAEr zPILb?xBBrORuuj6`(ypk2d+i?(Nfc2wRVs1z$B>nH6n*BeAw?5v_xT4Lz&_H)lgLL zPCx2gq1M^TASjiPtxWG|Z-OxG*eiJO!c1y~CyfZ)NT*H^1)D`jn z{6Z1-NS>cQm0OofPRrF@6Brh(NkXH<#CjH~O9EtOR?*S?)~jWR-ItYIve<{UaDO<= z^MAZ6&C3Z2jSFuREx)mA_l6T{cQ&C@jPfm z2E;Ft=DyNnF*BHBnW}(DFz@Cxwb9&v8k_c>ia*0_eTuwKGXN9+gF7O%ng8a@ z8{seQ+)%Bwz^qOlvxx=(+;iC|_+ep==iMM^WYWZx2bWnZ2dNTS3= zh5Dy8o#EDk*MMM@hw?{?JOU&RO@C8=? zQC*b9AYCX@%RB--b2+S^k;Hr+%zzWuQN7E{+rKc$=t1^`Ve8o16QptjQ*yJe^G@QT z&RfJq38ZTnfQQ3+# z;Qj9e;e9lq9LWO)%fDA?h`%q8yTbq9fqxcKY}vpwm!-D%99qgpq0^>a)yNIL)>n<# zYp12fzH>=3s*@DSTH+M&3zGOHd66OU4zshVsFY?qt);KIN(6KW=)NL^&w2zxv{Bw@ zPwoqPXJYD%@|jk&mCG|g0{(6x(&hL<&)3@_@}iGNh36BlCB7qWFUHS11Pn}akw36t zd8^G|TRu;U@R_<>IO@c)7XreZa7xv0SSDg%^a!&guszevdu}X=|8!hne<9ViBb}lx z)p||*UPgGjUB4c8inQonxP?B1XqtSc zidZzHiRt5=a4NyD$8S& zZ&tp$bxlb6X3YIb5=2icG|YuJ2JcUvfl(9yWEp>Nxh}UD6S5=3gvHv1_F8Odr+#^~ zWsi;F4k-jLuV^l=z!~giw3_t2q_@(40oNqPpxq)zLR2^9{pFQ0UO!iE50=&ZZOWp- z)%u@+U-ozBbMm4W-ET+;^>;5LW)gTl%$&eMGa}yIR77DP%l&(L|M{u7Z;lh!(pmuX ztqhx#TrhrHW3f@|!8D`Du(Fj>5 zW=&xEheTy0v-j^JY$MD2XmJF)0~TqLz7+^3PnFoC%BVL>^4uB+vtR~`JWx7S&=tKn z&oxo%h|+J;jMbN=V%PY}D+gA>B4Jm;d<`~)tl&Ka*QJWT8ypJ$ct~;Gc0W#rhq>8Q z9=!j#0X&;w;cl_s zq7T3v)dwh!s&qrkf4KB{^ENWkFJW+Xb6hG&EizMm+u8g5+kx0_85;D$L4Ri&tLTFH zrb4?m_iR=BtLuE<=Y<;McC&%VH%}PB+E@_4)Uv;SpsQ!UI53UgKjntB^J|_Pk1Nyq z`gpz6G3zI7lCc^-i&4=NcZOk)o9xwr-_y12Y+DS&Pw(+e0A-|V7e;&%N2H^!W0Np8 zJ_;bb^^*p{N?$eYwd!AG)Jzi?Y*Fr~zrD?wq9`p~Za2G8;Cq=!=*LkpXmS3ID{tg=GjIA_ z{>=VfcExz`G4d4gbcnzN0U7`5uApH^24>B&U<*;Z06LJmc12_ zEL1f|Fir%DG@k&B7LS`Icec4Jo}PadT|SEB2x;EaxA1kcRzL)ydDeh86dN1|64|nO^@g| zlB(FiSSE_(W9vG0S^MX0KFlA*tjFQc8XqsfRb|LHJxR8iKjFt%ovTVBm^+10CF_cc z>A|B-rD#g=Ug^E24y}4F=}YQ?ZMHJ|DHo?9YzNXay%-npUekCMiqkmYoJ%TZ%t~8Z znrrZWH3#D+O%R8nmsqSNhbMWUJQ|q$vcGW@_o}EbT9PF(;d$_8vDUC2x@`UW7dpyT zT3qZh<&nl}D4dXFDSJF@L8WyH1u#BlL#m9Wu|nr9?1ShjqzX=J?9EcEsj|FNDwobV zeeSzSUMfAJU!?f3DUZSL4?G;wQq+nIrZB+wmSH2e$;qiHDH&$m#5_v&n{8f)u@SQt zo-}QUii#K{g2BYJ8quN#bd6q(pK_QKAr7#mog4wH2kVH;wc@muOTwL?#fdW-^w?ZP z^S*;>{af5NQgO>7zViivvI(^}tS9ZaXH3kjtU~=5Xgs&!thn0*~T`dJ9h zMO3vSsWgR66_QDqA;X_G%Jqk!cfv3?)x(?nF_pZ@$s|pW$o}54-m7c8)t`5C#VWI` z9xV9;PZpNj9shg{k{sor9vM-{vgZA>RBfIV&FCH*>WWZa=YCpc+|^~HUmvBVp)tqF zciM7(ng9KeZ?NKWtH-cKMOa10alsuFaNj?A~k>x0IB`n2*kh4f*V$Vl{_-#=_r;1y*vYw~h^)yJ@%+oii_5mjiRCOwb*-Yqg)So}S}#aj7~jXfi_|Cjg8DUXt(efoh556vuHIUg5~7Jxz& z^64LK|6~}e0yl)C0B0J66sF?oeyRnjNWMCqSG3nb-?xj1Hj~_b>-M3(a-q(2jL)x(OY~0&(l;v>N_@vqpaHN!y9nQ>DFVosJzak0Q~(^GBG_{h`>ct%AM6TeI0HBp_7P&uAc~ zau9_Dqep?%3!;PCM%*8U&v<|*Z&28F^M*xJUiS5WE7E_YQ{7lmFj1Zp{JYTrU+L3o zoBPCAM$fVQRW5|3`Xq+$-s6OW9CbVqsmn6bR)d*iU*E=R_GaxjkL~BclID`=nN+$F zK4cZlo1sPmcd@RBss#ja7R5Z6K5Y>8$}}kJx#)5^(=b9UjS4f(%Eiy>+Ca<9{ex1P{e4@m*+}it-Tfc$a|&E* zMk{chU`jM#hox$tlm1#Xq-`b^M)@}(qWkL`M0_@*u}^{Xd~~hD zC%g2LVA!A}VU6^~H188don?IER!=zv|F=>L{%n7UpuAGudny|;N~85hPeZdQeD0N? 
zN+--my7ZaoS&$(uMjY;PK=QsGfTTOz8eXp7WtRd!!p6ep+|=jW05l1i_tE!J z*7oQ;osR?I*Z;K{aHopT&GiT5cWUIVzyNDHPMFjGH$MT#cIjI+(1}N+v-|lfRA9hW zlU<~p8N$vvX;h=vYg!Ur%Sw~|figp54O983RnP>h;Ix8i^Q-~0w;!JS zj>G0+i`c147k$P5w7OZVT5o@ZdKB1UKxBW*4>w*nw6s78BUNcdN|g-MwP~c_X}(Kl zfRvF+sS8U*^Ya=o7?=c=wFf2ElC0&J1wtihO0~Yt9Ox@{KgGXlTwc)lBaezZpKkW~{)%$_Jpi~c!9rhJ)Uwic+5bW*md*KJarI$7Q&l;cqma1DM4IHy9ofx;Q&+D{N zZ{8ks>InYxpiqVuyr#InyoWf8pS*;kHElGnDcM$KM zkBHez%=shnquUF-!Zi6D;*F|6%_sR-$&dtXYq($*qJg{5cwU-ny#sgZUud?|Ws^Q$ zOXHLrTL-3FYX@SSRFo59BC0b7qRS$k-?%}Vicc^Vvf%u=dr{5@rQ4JYwPn4*m-G43g7g7>+eg>kjm7M!tPQO`95qaEl zytfekUq{e&!Xh7qQ~%Uj;9V;H_Y+ZZXxYu&ze`eN<8x-ewCu?p;~^~7Fd!2ybOr9$#)jC zrxAI=TLum0R}~qei0~8$he@QIP+0VIf7&ATi-msNN&6`>oafE2=w75Wy+~onuWiAs z%g1zBTLQW=Oyu(#%~N(2$6x%GrJ)U#HNf=t5NLh8Ud9D5>(u3E;sfgg&l}INvc>(B zcu&qn&mFdCHM1f!1)Z-+&-nvY4vAVZrCOS#lE(`57WD5Bf}dMg{C%~tW8u7SKOmHJ z$wfNmkh=KOJ0!u3-s}INDfMqk4FX7?g0my;Y;Xp3^ZL@6~M9tu4E?Bvq>ZejD z&e5VYUrAqjqcArMqFaU;KItGmBV$Q$i1}k&^gbH+LERmxn1z0_gsGMmcTF>yHSG^X zO?+}AIfJ#z*%Lh$1`bMBqPetVUyD*!DGyckvbA&i2;UZ2EBz5^2NW3jNJOqqWw`yK za`e%YUbs0psgopkf?gf@UBc)zD+Q+fHdd8o%rwqs3YNB^DQQri8o>x0DBJ{+1x9GD zV%o&5p?N=RI#wRmV9^c;k}aq!e_04A$SurRR^vN3#85B)q$bPcFSE*>n5^oqz!T#T z>&g(1E){iC$a-Inwl=4ddP@>}YzLF?6SL-moGgl%KkrR&nJU(oCY-^6i(Jr|YM#5G$uCTy ztY^bMkj(>+)tFvW9Sm2a;2FzrCE!qwc&FF9orPaD4?na|&(Cp*7DZ-^r%t7$B79cz z*4{4S5E?BoTM((a*<>qY+K1`S>U9!{SBV;&{6cV=%b~_p69$}|8m0xr1K>5;Hds8x z@NF!$el887Qzf-Dj^|OPBiB>XZ=FRPrS+Ow69Vwq>+5715$NWG2a5cWRTGtDMt_S(k8cR37 z*yh16KgI~FO&?o>(vLy%)%+^}5eXFqof-jOZy+^KW znqEkqKi(c?LjBEp_#;g!5G%`yN^5RIsmIj9v|VVV1Pi6Usp?IM%Ojb2N7aSo^W$%O z6L;79)+y<927QZ;WZ1S1GLpHt+ElN%$~ci~s5_9C;e%0VxAQ5o8%F7Z4rre7)JisB zvG23^jz%hEdZgH5&tD>v39qwXhRBXxQ+H^|gBB?d<_cDQ_5=hmX?pKXj7V|&qUA4B zoZlhrdNRJPEQ=T$8v`7sIw;f-c~-+VjrK&E0ym2WtV(yU+ocbj7V@RaIg)w`A~;wx z;vwl-uvNdy^MzlZn6x6bhxzitwO|jwgSad-ZYemS`A+98^8H! 
zhX7}-4&`%MVWnAmPi!uEO0gNI?kayEzJpl|pDPYxEfc20o1z zP+S~mqL#&fI`t$eNS^reZMoFX=_VCN)$XX!{b)kkKn#K;awWYb4m_u?syp5h zE$%@AQ&G8a%)Gz*Nm)h!CCcckFaDOOcdfi17Bc~GZx%~~W#zgnI91Glv@RkZU zg%7hofAHE%04(02Uj{}#=^}Vj#;Mv}&Mlh}`Q>7|4 zNRoZ&!>Z~PKVI^pRu1#(Qq;6YLMfAW45mVBiIt2<7U{{>_7Mc5hZPas&Y&`a_3 ze^x*a78|5u+p&^S|YSK+_)~5h%1?Ccbp_p(mmeFyF84 znj1KOYD9L5%Y`#NpUjP{!RS>JS7EWNNRQg3!^z(Xn&Xwr(V+Y;qsY1EJk@~vg7y%> zV2msVzfE=Un5w#{F7A?i!R>gX2UbY|ddImF>5RFwy*3XJ`h4t?*31H&AgcTij0wr^ zO>-#0bdGQpf$UQ6e>CAvC6QszVmD^m6}&}O(m~fBPj>uTsH)&uK=T_+h^X?_MCYv@j&f99tc3U3<=FyGgE@YzYV5T3*H@ZV85tS2u6s)3JLa$P#hGRAjqP}XQDmn z(a;I0z>+jgDGTKB3se?XG4t1^_jzGsW{$>AF!~zd`TvW7IZAQa27s1BQV*6*Wc+)% zJD5=ezp8S&fYUbieWnOqU7e8yBUh~8b&9rp=DLnRh1 zArs?0NHa`TPo%kkLT{xqc=XG6jTXV+UPcO=)HhnQlayx2WP~Vj_#Z#UM0W7tpMm9r zFO4PY%wB*Q2IP?~lH$jlci`&#NEoGK=o<>Ea>bwOqO-ILawp06ZP=NsGa8xL5>n8S zP;xT!(>1gil>!yc_qks0!iZD?se3(|%usTb#HhKX=|}RAeBA5REWwBalZ_`=FB|NP zp#%=!If*sTu(Tz;GBtc@ga@!WMTYE69+lB>ZzkhwExak|u7_|{eN9nJQ0nl}ZA$t| z1}?af%DZ!42`B#7#Ctv#Mo9i6IkM8Hn08EvWQnV4N=Wd+z)wITV2aK)XuvMl0wx76Yn@Rzu}89Nb=Z?<C)**PdV!iGL#GAC%pF;vG?L6DSaXmuqnss0N3J3uSa`evHA&|=$0(KAogynfoywopb4Y)&!oLyI{5suH zV3d1X897pmCE4v&L-UNA-JgV;EzqltgTdp*5X})Cx#gqKcBy8)#^2JLn-0hani#BR zO4h5Y#(sO8WpO*U0dNgKC6u`9BzMVj)$|q3^BhCDVxc1{trXw>Sf>pkr=uB_0V#Ar zLw5@a*lT^?G2lu+*D5$O&IxDWrhq14)Ynel*!##r5FaFJ9BYi!`~jI^mO18Z38W-KEmvXmC_J5S{LkPxf7k;A1f&yn8-XT4s zfmUM3joEu&SeIJ&SPmzd9y%fXR0Z_mdk@3UfH)^WpQ_~$=l}U~XP~z8b_JLkef%6U zm!2ab|CYYkPVBQJtvU0pmFd3F4l*VIRf4{tMX2$m_8X=u(mR965rNEv>lFS(Qg4Z> zFUiS%@y!Lv^YibJ9ux(W`x5_U0T40jsJ%movaO9)>??i-F3jwFL>BTv}Pgrx|Wx(lG0Wn(;#h2 zBsL&bg*fn3w|Mayrc|%~1iS_6eCwXGXU*%kd+I7)cF41?-F6iP_LA-QW;bC^e}0`# z2noRATiNO^&)Lp?>e^5JN3w6!ePHHL?9AC6!#VDB-n9HB~Ce_;xJiuLD3bTJ)7xE2FOZZ)g z8s9YY1B`|GXtCPf_931jC1zU10?vPZU2*_@Z0LMBKvLfcoZtW6zxhm>AK#w=Xfhq4odliiqv(7Cq^FSBM7)0xedo z=5lmHI$u|5uTXI&0K;x}ywFC)2MIGm<+Yk6CwE^T-A{8Lp1A%5T>NeSqwINi)@Y+OYsuVD5E9h{8z>`(#$ExyEt_7G4UjIv;F+wE%7XB0E*HX`mvbD|X=DB)6-uz7 zsY0)~&L{VI>5GaIszg?OUr_A%x4WNIzrS^5Ne&x3!<(h;A+;0!R!uS5gqSFh>fR+G zX^#uO)hj^to|YjHx!?MhX6OO1`dP-1WZeF1w}Ve!i;Ih?+I9txz39ttF72vJl9nTE`4$J(95l|yLpMEDf~}J*H>~p&X9lOnouZG8K>*O- zh9vaAzC1VaV34s2dDgX*g zfy^Jn8Wp*yZ~R>sPkx%E$&{r>faZLzYQy2ya5<`WH*p9orbe5cR?#E?WSnro;(oR ze5l+7C{KrK#Vgid(O?lh;iE&0ROm!9kUr->43PyWWk$O>XYo3wvgkrL`Xi)sG0^OW zRS3SuV~B;#&df~kzg(vod(lu8des_Uafq;ZbAn}5U?;b!jv)lxqK%g< z4f**duV3FV%WLBRFWl3z-`&}AW2yR(@`#eWr+11f(|2(g2k|sA2NQ1WF&zT+>13q% zgc_g?XZ-o>$G1>4+tm_BVCH7<029lz0aHkeWF{=VsLA2O!ZMn@Y(H&E%|bMaSucjphI(CTu#(Y4Mn9uQ;_TdGlZ(fMC2b)rrT)F(BtL9O0m1kkoczQe$_~N z-K`m9q_>nt@#`m1jl#mhg#87Ai_JW(%z=C&3*N7uj$ijDn_<} zpJM&m!ge2deDzOn!&)WSBvD*B8&$j>T-7!-#W3_@&6{AO7{aC%8SKduJ{xipJ|qtR z`Je~@DUs*gi+toID~0wI#O*nS=1-1Atf#9hvL!j(Zc}V&klA-b+;5gOWyD6DE%P(!EokZ61h=FA{dHkgb6*=r>!apDFo9nKTu7%A%=7Hy6t|3e@5p0&f1FgSwI{j zuB1eoGfb7AUZT^`MCj^QM-mSQ2|>k=Pq07)`R|R6j-r}ZjyyRLhRB!vq{4(MtNQVo zF+@#a2+^p?aNsJFOET~3BFxzJ8^l=ZJom%3F;hFjSr$7OU}9Ym=i-T@j!$a3=a+=L z^35r}D;foj7wzQNY{AhO;?GV-hT!C%V(9EkVI#>i#_l{-KDFZSqGmFD-=b|4Y_g|> z0as<2s|ECK>s)blYmB?uKHI-#m5lh0kzbA#4q&zJCp-S`{Ckkayc~*@l)&vY>`lz= zw0;hwgx9RiPn$fWfaG6?=d5AGjx$2%HmK6Qhp5BxvJBkwu?#NM+q2#4Z%j2UW~aAj zfN^=F#cMV!i7wBfh)Tjry9*#yw*is?D!Et^?j!gCI-cwKUIaGLS6`1X^fUusmqky_ z_q5d00^6N;U712(mPb-zW@g&lPW`tL2%0`+AY}S;Q~uy~+UU?W@+Z29V(X~WZ~nxI zf~}mz=k#~$x6@S2u#B-Nwa>m6EaytP-dbLIK0^6t=eSG+N5a!p~Z1mD4Zwe=n~t|AJm@mmDAadEW@W z2BC%mzu{v04ZYqAV;N3e4_2Q(otI06)oj+m?%|DUlh6ytueM&rL4{E!|(A%Xg^t97L7O|W7$A<=xQvZaWIB#XuH)q*Xy7rr+p)j+xD;a zRrB6q^VkT0b>=6Ay8YPz4;61!uo{Mh$QY03@m+RCG#*YY12DyhF`y>_B!?KeZHJU* zd)pOyYBKx#U-Tem=d(N9Uu<-Ga}|`&marPv2XPxWpRcpET@*A<2sgUy$_9;L^2=K7 
zN2cIxWU-qMMxU+ztdB7VnDGGo#pB%TFosaH!*VoLK2-?+=>ri}D}8WKj6!Qwa3)1r z9%eB8c{dP&@vd|0`3m$2`18+z?IVZRov{(O3ZEqtLvS1Z0x{QFa`t}m%(B=L{^A?P z#-HWN#a1dgkKeeJVj`+Kz>Qx+2L3nK7xd5M)eT!?f82OYmR~O0DT)p>6QBC%pZD__1UsbZB9Mw#srz!F7nnO9 zc;q!AH}>yp8BXaOEfYB34(z1?cQmkM(voKMItnbiL`4MD3sqa>$V+&PjuNvwlONMX zQVQIp&k=)aZqHV1lL4ypHJ}ekiawpLq&t@1S0`)nUw=$9_KqN#>V|zoO{xn}K>_X8 zv+ecK{-`!!ODzPQeL%hVl`dNC=W+Y<*;kWI#1pfy?CyB;wTMEBniRxoe-NEv{(>3D{Y{^>f*z2@%)o;AR@KQ8ydD}9WaTSyO1v3|Sv;mX>V##$f&TtnaUheIR3{S+XUfQ;WcaE>md z0C@iW4$qq9a@}u9zpO7e%3(w-jc78pqP%%a^2@((a8m5aZTw{Z2^AesrpWq>3I(zn zINyCa7riXPky*ORaNjXECVjd${$SnFpE^$B5=gzPeRUYXVLn(5=#gID@v(0hw%-9< z*~2{FM`Y^>i=Q@^2l`G}B)+?qvq2%K0x{MDBUyhscF|p9bp$-?+UOYfokUJc+G@So146@07n;hQAnHAneIzl$`tOEk$L zaqmny|Jg1Ga@67(zHHZMO}(2cN;LDg?ehGKUSwhXC6w$tB%j21pxq`g+FmXgj0ciC zF=2u?m?s>1|080z8gusY4AfdE{Ajw7>e;N=1FLcE$jJjwlvQdAd-od}6<%6LH;X|* zE%u|w+ipJ8=Wk)S3bGfJz}0)_7m<*Gd#CSrk7{F6+Y z!(^7w-4_5Z^!oHL*(U@c{rD*bSyNbyAzkl$d;x(^!oUmXo#6LC8Yf<`;cOX0VN+fp zL`+qq$pVrXf=ckye)SVT6i_ph{p)gHx1Hw)g|xX7mw&zbuo>0&Dv(Q2INqs@7Tm-$dE z69dx=B+h@X-ds-hebjue2yg)`m;nkxuFv^H=hN**m*dkOgpAp!DO{MLkkI)VW=%XD z@&0|1&eha5kapuXmM&JS+j!8jwg&vZRQbf6gWvPzJ`4~X1wkSXsQ|EMw(IiuhTfe5 zp!5M}v)4=Xy)8nW6uC24s7sp65p_SQz^jZWZSJ1u7SXNM`{0-HlGlt7N#3iykV~;68=? z?eY4&)9$8Xanth~zqOTReI`-9c9Smt#F=6EWL`NwJsU@+DTGtOSpVan1IV9AIut6W= zrFJZHyMqh$$RU*xUs%)!{&xu!u=CZ5$_A7KYif2QZ$R}S3aJn2xbV%haeoxF(r^^N zWX_7ZcX*;t-#LC4R8v?leLqJn&&mJGCc}wv`wXQj>e#i{{DXSpxp>Ecd=N4j$phwY zjLHP2Yk35`vG?TVpg(cMi{^&B=2$dg6O?pw_x|u*2)EqW5iN+nQ^iR7f(~8&q5~|3Lpd1ZGhIwvRHn`9JIz zQ-2IbCo|6=)1f)!X0*D@hDQ8RY1>PT0=@H*?M`AyA<_T2mqe=H=IeE;RsE^s?_-uz z7fqqeXHi{44+ssAc;Uad9meOdrjrrt7AQa2`P%c8Y|X2vLIwnkl73lRZUfto;IXwH z6{rIlLN7-)k;fZ;hlW~?K;FbS7hI99^?3T_guWBU1}a8~DNxjznRI2=C&5N$j6wD{p1t6`ANS_!RlQVl3jePv3@_VdfNI=lb6m4S~ z(q*7xC3b50ie|V=3tqCEdCz(R)b3vU``R97YY>zjvX}cIO33u?VKm7@zq$(0HMM6> zKHqHKJo~a%^R0NPao1&N-D@c4(AZg2vp|CiUBG2HPwv(Bpem#2Rc9|ZE$dW3$D0>Kw3esyb``O+YuaOQ`K<76G*^RO^WW3mIH$EbQy{zJ<4$BHZZUoeuCXxm z_aFb)CxGZ3ZrDG{>hpMP{UN}2Y-Qd)kdidwJKkv8} zxh43$>s%!Ffwa?fy$z2^yX|;bHs<|DWuc2;31eRl(z5X^-orf$MPI^!AOut-_vT!Y z$Gbl&=$SDW;gDqjytkg2mC zSRo4Z_D2%HRFfYB04Sq%-pkX>eK8$^14F3O2_AASw$aQLN`+wOL7$yn8WgPkJ*76W zB#<*ieCZ4ie)|czMGxP#0lautgo*u_q)NV&E&HbaLY0DMJn*R%HrWW3C4LY5=OD}J ze6}<$FBN*$Om`L@UWfYxLVwQ%0X9K`-QZUc;(g za-?fh-=#KvzVt>Gj}Ajg#7qLq?We^v$a?n|*KST18Z4 zt^gbu&`f%viW9BSQwT?!1oDV-FV5jzRYcxm(4H=K>VMIF9H3wahTS~}F>Bc7GlzEK+n@<0IzVNMpx=2J94=8 zV4PIc41QVV{yOoJnvHyPV_i||fx_oYjc#M6-asY5!E2gTPTHSz=eGf3aJ7_@ICl>% zwD23xtiobR`;%m2i1S>b5bfwTIE#$8bm&B`I}CJ5u@>e0nyE!pBBh8l85$8E=lLCu z>bR7}3Jsq;Cgw#^=tiyaA2a4WeKkixBDG)sOW`y%ZIj)-dz`B1)U7-2ngKMQl< zf%RO3Ggh+Twgtq+?u0K9BuG$eZS1^yQPhVjr333sC^2Q1`awrOk(VgBlF0naV)w*} zjq;J>mVoV;RaqQzW;<=%oWbR&0IhuCJ@KqT5uGGLV~vmc-8Zzi*R#;QQL1@0DEK21I1 zFgrbG1W&(%xr5=1W+kv-W$SIjq@ppD0#Po=p6-*d_AI zpzH3uotW3rqVpoK+pVbdjYHt@$|0}g3OX4tb1CIF+C&EuP|a0wnh`4_YenL46-+Nj zgRNp3i{pI7c=&gzhyTaaSpdbkZBd%w5ZocSTd?5n?oM!bcXzko1b270;O;IVxHTT! 
zVg5Vw?z~s0sHO@Ey1LGPWbd`U^DDwq7>f?K_?*)k|7!w+L~kJ&7gO0RJ?+;wo%nH6~Yy*?kow zjR$1fC8|xFB_7DH8~t)zrg*W~Z!)36a*IT1di!{ypiggWn) znnmQ)BhtNciWbhZl^&wHC=@Sfi2~`pv1i%zMZ@3RHM8bp)a>GD1UN6nU| zrcRK__h8oL#zw;2L72y?gEfsh`-NdSo4v=Gt37v{b@6&~Ay(#;LzAMw5u)78G{E@6 zf~DR_OC$maAc+auZC~ya%(I7hz=#5+9F^<1{5QMCR7%eNG@mx1mS%hGqQHOW5{*!$ z5jngb;-}H@oxNBtwvN~PrReh>id{C#`!d-`&42#8n$FWo{eW!R47?<@t_q&j>KU>nv#)Lo;AN@(7R zEx9ktBfun*eK<|wD;=eH-tRejaNkPjNXMv1yE}z}copB1queczy;$X(!zU>?{L=v( znI0)B$5qM37&FgyPm|hl!#%wO8`1=c-Ixt9K0jYm|KT?783JyTeYa!b=8|K96HbvJ zo2lPy+U{r63^WPGd2+a*H>ek_Cz*3?n2MVWR%Y+94&^cn!nuJcC&&_okD#PNY(|~Y zv2U7No$i-48=Z1u#f{&a06ZioUH{qAYKtz9Pa%TC!l4|k4Vss$Eama&Yqj?A`1s_> zY<^daDur=yED^F?=;5;Mx@tndRKqZy_9J~;Y6#1dXwM%aqfx-4l@YSXiiDZLozh{< zfCUq+B1Ma>D@P5mjMaok#iT0rKNM}^XVXh{cB&?2)fXB$sI@Je3XoKyaT=zmnfG~f zY2gHhV$m@eo$Qu=2n7$n-Q=2(uT*QNrDe)^(#te*Mg4Nm<>%{gr&7SGWXal*$n@8F8nF-dJc+aJOra& zw%&4CkXwv|nUFXh-G20$@{$VKsXyE+`q_^oD<<~2_v*E3o2oWI&$3fdQZLGHIp`72 zagq}fj3#dM6ijU>z80KKq3u%YkVi|>n<1jS%w7qF)bf6qpSsAPfp###vVFSrj57)k z%aI?qOtFz@^>6yctYRxK0+9)B_CY{t2}-Tnh(v&FcTcH zbs8m~Rcp1z6%QWex)c8LNoTE&;Kk=>`I(!c2|3j_d+TK9G$4+sCi`nw_DhQnM)!-X zZ$Z_5-r=+~CP@s}B2DPPi4cq4(2zQ}GVeY@s;>7#-5ohwKn+|Usz1|6cSw93q1>kp z!0H(t9i3qgO#kU(FT1VHU+9#w4oj!SJDSD^#;DZq)J-Ne4YsB`fe^o~(eOScN%bQCX(B(ajHr%Bo8fM6#;wMt+13{a=Mz%>rOyZS-1eX(vZ*!d$0e*%rmpMC)3?+G@ zN-2=7W;svLx-|{QYZsJR=-TbL;p)-Pb(!n`{v>=N+iFANWdn8vRO^}R*AVjnJ-}fx zv!!#4^aN&_nM1D{1xbmt7kCDi+B_+*TrY-T1ItyiWB1(>r;dPZF zsU^GIkM9LTh@TWJ@IKAdYjz*cZg-|H|4683Dz-EXvn-HAWV27N)Ts6X0IO1OH~k2T zfNbPb-&2OG<%Mj6@$af>razP4U1#Zadxz9^V7TyK2&%5d^7YE21hpP>ykx>EGh3>(>TQI^P=HB1xh6S(q9 zWDC@1R^2&3C`d<9qn4~DJt>YPGkO04IpKffu3^BhS6Rxy)*FoeZ;M7h6IPK-JeDw1 zGLA?fCuScM{sA1x?Wfn9?c6V7nLR?h))C4<0K;T~s8WWvKAjd>D2+XM1L>k!_7Sd8 z-23E6lv(rJ_)gBiRCED#UVCFHm+#ORhNO>iM5MNgvQnpf1s`3pp(*ODOO}%R@CuUiC zom21(jtBsgEA8O$1oOP3;OFl&UXrj>|86p*8pqD@&*|XGW9Z%o;H&NAr6S=afL_6F z5lCMxsryu125Nk?zX&Po9&(eUr#8^<#~S9oNY|}3q$Liobdcnbn@{HqTc3^c%=-(L zN!Y*Bm659TRc#`vPQX?jvmf06n^AkX!E|}MvHVX;p}6C;!pr>{LK;lxgj_Crb%k=5 zn4HMQIP!`axmfziG>5*o?Vr_p({N;!N{^*7rS|srO6f|qKroMGZ#OM-(p*kIq=BlW%O9E2#K-hcAkL8U0H(5_xbjLN^UnV ze~ied7jl?-Od?^$9A2{Aj{xY+Sg+f69aFlU=rkKam4@GpMDj8AT+<%qG-8LOECRb{ zFQPj}s;rfqf@U_fG&!UOKu;2O-HAx2t75?lEV7!TkdqP5tPwj$6Fz_5(LAKT3y~Yb z9t(Nb4B(#s5ml2u#Qav@!ON%mWHl_*JHz{Oh7EwPew;P~Iy5%iIKQN7?njbhpdJCJ zOJvqEjX5MzuYZ7JX5*{n-S+rUU2;E{BwicvZFMM%Byq5-w2JdW)VGv|XHXi2OriGfxTC6$-+Geu0@ z?&Y2UXn6V8?Hha_7@y~rX5_H@-Q@Db(%?tH<1&kgC|2pL9tt!&Crl z;8~-g=Xub6lzX{Qxv~N9O}R~5K~UXB&kn{4US!f^_8>dVSBF;lkxoB-wZ zT*v3*@xWgAJ1?=92lvrw%XP|sm=r0+a-HmxA}mE&3cmAFR6pzbLqvGBTW!g7b@*ps z7?VPha&1Y0R;kEtqH!nJQ0LT#`6Vw{VMW6SvJ`bWF^aIT4CoifXSt%X)jE>ZUOl%J z*Nvr*<*m!sO1jFLnqPDVmq8d}D3G7UI_7!EKY@P|_S*UQ*}lJj*v(;Yh9;FJogtHFigp4^#;*fwOUE5)Cr{VHj)oq&SyAUu_o8-R)5lk zlT1`vEanm)_}|Pth5awqj=G<%=7y&-+iDOvSvVN|mnu>gai9jz<|?r%7=akOUJ4jQ z+^QNVHmFiip^CKjI%+jIww_ z`19+BdE1J%eP_Rut+LW#F+K>v{}Xh6{+oK{~+EUM9=!Y|5RYCxSl zn~YAeRa|y~tfncfNCsfb zjU>op6!_qk3ty7HweN9QQGd!@lKK6(9J-V@x?&%}?aJE0M}2|+23OM ze;dJ)zF9$M`EmdBrPHW>^_n9sN!ECU`ep>fYOeSg8;E9}W~Zlj5tTxR&G5bKo6T0Y z?_mrsSGaScO2$2a!H_$a5zq3^vmxMirKO{Tih4HM#P09K2v`9;6C)`ihvb;#a6^E} ztI*h3yQE}3N6bxdK`(5-vtb?$OOjdq>Jy^Y%R^FigX01vNq=xZpcX*ILTQ@TrG;HxjKo`qiH)r^aos1C1gssH(&&ODV;l^W(8Qez zxtq4Xo5K1rWtB_qQLaDdieSSgs0asQTj59XP`fY1L%*5+btuDeDdg_qUsgLdo}+=# zm5vw>;UChx9qmo&qoK1x^xIQhtO!(z@4NJ(mew(MEm)unF2IM~K-_5JUyco9JA5Fx zw^qSVMX>N=yZX@#)IS~kvS|Efb~sr5Mk<3mq43lH3Oz>oXo?T5FH<3ln!JA2x28rr zz>%iz>Q=xKI4gKVq;*1XIAmZP?!#wNArJoGpSuVlVd)Rj^hFU2|` zAkZ`NHQ8QI3To+W%3zW8{AO{jJM8IO??V4{f(Gw=z#PLF@BRZ8D4IC?%oCo&EwzNTHS 
z3r-5*y7bk_o(8d~<5iSMxML!Qp}AV;eP9;8++Vx}QL06#joFo^C#XD{s>K;mz{LpKU8{Ha1tF$g`Lx`kD{W_Ft-7CDC{j zon3a)^yR(&_5lsIl?s*5w^2qx<7Fyv)lQ>&MUL!>#{hlk2B_gGGcy6Lzs`e+wTdpg z4X>7ug?(}^3MhxDWbTX_dj0MoR;lM2FhvI7`k4UavjI|3(9FyfkHcRu3QE(!{cv!J z_bmW%G4%SxOEnK~wtS}WaWjZcm+HS#XE?C%@?gDHLonIx$M}97_Y)WQZ>?TuDl}t( zD0DqDvxz7Y?9qg;kBqX#arJCkk`#*Kf5S)>%Z`o zastbig&jk_-(!LVrNU3)Tzk@)H4RKLMB9@e^csFPpoLD3*}vmUm7}fat2MqyfD+}W zk2XUofWPTxH$9ouY6KGuB)W3$3xZl?AEL;UDh%goaH)=rjHC$2o%(t6`kAo~9UmWG zNE&5!5bhP5u;Z51E{i-&n5fJht0X&3A0*v>kc?hYb-GLV9uY22qbovf+^l!EB9f_z zkqzBAPu$1?D;oX@QNEef0)h7yC+h^|V6!$4p~C~@ccT2PX<3#*xS@s_o_~ZOW4#3b z)&+Gd>$XvVw?6fII{h_XM%s_zihXTy4PCS&#W=^);Jra6#Cs8B0PKFoVvJwn4QM|l)6 zf@`7#WsI`%mwzkdLM=ANO18k;riBsq6P>?umu6sQv#9APHvqB0V*0*N zo2A^?pX2$#q^&n9Fkw-%72Isr-+^%GCLvDqM~7J=rfz@UzhLKzoCILn#T3P7SrmnG zgMq8ih zxSymj?-#O6=aR-iOK*~P5fX_Wh~dijoUR=3_HqY^7tU@NyChYT0ayf)g1{3wh&Ymx zIFoH$8U_J}$?nwyWx#AIGZ?y|;4+n=^X-1U+++0ZYgljR8o7Mx!RFgbj_1j-2n_P3 z>%)do5QWdx)6toEX&wo#J0ywqR(A#}YMbxzL7H*)!nmD_-y6j!??)r!V|D8$9+%ak z{`;%BQUNpu$3FeGteM$i2a%r#5ITG~ojqY@-rEUy(+F34dc>?wFah}Pa&O?LS_qLs zH&B=G{QNvaOb?e6AVmYMo#A+UjB$JXWgU$om}`*)hIt5-xfq~tINuE8K*%y0|tg7O-t3!u0# zPz}MxxjaBWhGf!wSZse!0@>(xV-4wTI5iv$Bt%g@e~HIZ%yao-=6`mm&N5SJ{JJGh zUYPs{>o@j+#R)G^DlJrRCwVp{ew_}KB(~Tv%7#5f3j50a^@6q?MFZ2xfEIqgFl}dxOqOKg+)l>TIP1iR3_quL`Iik#Pz% zmJ4pD&0?mLhEE?{X=&m9VAZOC7*Ul_f7lZR8K#`@4M~K(56=OKtdVg64;5F$@w(BnORf5C+jCeLo=ReQ-m50l)CZ0t?l!*Q_Y8o_AewgdMO86g^hWg9 z*sDz^I{iPq=K|=Nh*EF^=P4C(Ci&O+*xbI!*eCe`dUzrxgZnD2b@wwAKelbZr2d2$ zRG5K|o0j$8&6eLxPgDiDE6|0_Q=3MB-me94qp_uR7hzX*lnlq}_zG|rA^A@9%^ukV zsmPpmL*UYM9ig;?xyTc4o=%WQo??q-xVSsfYHhVG zw!dY&{J?2)Z;lxAYzolqGZLLPnpf^m79Iv3wm!B8kjv<#--(OK18WBzg>!+5f~x#$NQa(flbF}p?>sz5j@Q89N_6& zW}fr7fWfT9Ccb){v#l)3ls!y$Ue;(o1k2V#HEUezY3RFRDl+2KES5v!#l_qT(V0}d zAFlMsTh51gfdf7WBQ9yIJkLqMF~a3WuLdQW$+^b|o2I`)i3AcI7TUFOctLg)k!qKN zE(#?r_NZu-9q3X?NivG?AmSdNfX`N`smQV^=u7E&7A+kT{6$e?VBV~ITFqxgM z2*h%Hwnu50C?A*pYF3JY@mZoH?7&>gMQk~1o6Lq#F*Fc7vJFAr%|geqY&N!{xIL`D zoE*3)7$OWJ-lX2~rbDe;5MPcRqU4&Z>78cTW%LQh@`mwA9(THSXda;q2 zn-*Rv7~2Ilp^4OxR}W~;J3au4;`w3bPQDDMwmW(f-hh-H*|}Z7qNM8CpZX-6sOV^_ z&uEg#*+bs*-`Zf4k}~tu&xKG&{?>Xn9`i# z-GQ`)#I{~7(_#O?PrTRm=+SZYh?SX)p|X@>JT3riq!} zfJr9Xky{(9@T7%8Gc=|@kNJWK+Uhd>9AzBzqk-m2Q9LxiRAj7>0ML74I&6?30Y^YP z?|D17cgeO}EJFnmHaMNTp3Ixhiy#y5c_mW|#Ye^DO}+U(Z1TKUID(DXM3f-W@Em;M zklshzt(I!_@$IgF^Vm>cuj$_r%18;OD#=C z|A@}@eb#B}$WP)^rox*Rayj*?7L?VCR4ow~5?j9-qI zSI&>r;JJjMvkfbP7*{AbbwlMMqr7fG5NOhBc_MYys;O zk<0RI3DL=3o|im*aI57{Q$;SX8jVhwL0Xw?&8>lD$_eH!ct=;EbFN#+lsW`<@PCIu zEiAQYqOf;1x2q8yL&W~a>(k7aFcRhqGWo`j59olIQ z>oCkMFGv3IaAD$<)Kutf$Lg~0Y-Cpb!!l#iZ8$%s*GlBRY`{B0Z9p7}Q90g={p4J^ zttW@Trldp@VZf@Pmj9LITMQ&5P0f4&IvYUT?{YfdtdQHC1yx^$(MVOXMDlSMg z!a#E_R$=Z%Mn3d7?yeWa9{|J-+_QpNFs3!6kSX3Bn7Yyr=MuQl#$g(>!OFspoEUua zIf|?r;Z!=&JJ&qEmEF*#^1(roR^7a;AQeLEonI;=AC<$))!I+ZWIqJ3;t()?<`Qj0 zMnDA(`?ze4J!;R>5GtN5%7S*3_)n6-`0;xf5T^HJs^|-EK*q5X{-ooft~o#6=90@R)mN1L zv^9Q~M0><0T%)*TCOX?zAa>vf_hF}9*I9JvM0R_M=%B-~-fEYf)pnC^wDD7tzprz; zH>YsX6;Ne_3Q_xs0gk3!hFJaSjiv7u|8ssJ2j3<*2l5AxOEOo30g+X$?P}dlF}B|5 z1S5|7oyT>ZP7ib_wM^A3#W5L=Co%V$XfQu)f}H4JD(twR$09MxhQZ!!rpgreIW6Dp zuc0B_fWF^sR$SFQG$2TFqfB-c`2*oVxh@vYzPf@_b!kI%c~tC9D{0UP3W93lb{^-@ z&PbHBFoqgGht=~um1`8gIC)FgSze{CBLofeiAtsQ-0*Ds!$1Lg%tG5_ekJ=x7A!0k z1d2#tVyKYNuoLyGVhj{~JU=mULY#?6GpXUXTg9ob%4WuK zgqp3y9K{7WQwR-z1j53L(R0&;v2*_N>@pwv!Q30AGP?~oPW;*+D3(A0!V95Rey$>+ z=0(V|d9I2PF+umcl7KWN=CwkJu&Z-~J*sn_&X)~$f0ES4YVv(6BJD+{*$+`+e4;u~ z!_2zsXRFAYW+eTQi%?N<3_0=E^t!V_^eiE!nV^_4Fn1XpQ~9lCur8~9QizQBXkr_~ zOVgZ${?IAfAmJZF^~>*7eHHn}kLLf@H`>JmUbLSz{&<~Mvl2sRnB1I2&+t3X=(^7( 
zaDfuReG}kyp@r1mR+zG4D5B+Yv&J1faqiN&NeJ?*ZST%x8Cdk6*wS}!{c?K{;8v;T@|jHC*~RPO06=kU|tq_k)`a{F|Qj* zWuhC%3R3{%gy#3;E{y*z+2h5S$r&d?AH?D?{ys>f%%7NhQTBcNq&k#Lci`F>EgYHX zUEt$JV6h8!1x}IZj!~x%<%zEJI|24JdOP#yJKJ5E-%d{yl%m)^HdH>bpJp4jTgcO< zL`3dZzKB}{HH;$@332J-3iFMPghncenILHG%`WLE7SKyT3yHaeRY*BPh50m}F8k~m zU_n?>>g-5VC9)4fXw_C})mSZfy6XH~+c#Y9tFv0F0GibEi;2~&ls$Yc|NT&L#RKd$ z@1K;j;tqj5*9G;>37bKaY3HP}obvcy&ci44$ zmW?Rw;wn;EDNG;uu7a@j=frYI!k5j2NTGuog*tWnMKm>2i;AhT6U#n6HS*oHqMCG3NA0b(M({+bSpzDlr%_+ySq4yAqm)Ta)H zcKN5b2?Ja&-|j~SpKDNj#iG`h1|V=@Na1zw)rk|GK^E#>sg=oaj zqo%@SxJcMo7)ue~T5Zr@$=0}3Q|I=BkfVJmLTG=ru{bv%Un-$4$I{ouhW@=$k=6XMcdp9@hNLrq8^<#B{45T2^2yEMyq2z^;OP@R8i8M8E`Ij%W`@fv7Zv(gIFjGA z+h+tmUV(eowasdW+Ly(s2o}`E*a~KldjC9%1&MT|%GYGtuH@v_uk~f88c4HX*(HLy z9j?)lads;$wj=V{ClutpSd#e*wz{Rx_>v059TovXU*+D^T-cCk>7<46J?Lp?lP`k# z@9RKQC5#IqQV}Yv6&b?zR4I3g9>TXO7|>D9#2q`q^viiX8-EpXEXK z(?-=!yTmNviEr5v6>X0pv0vsJkq%B=E&&5KGevQB%4Oc!TC?-ogLc=@omLd?O7#~x zbM9=-DM#tXi^L4vIe073 z%=vo#^#mhsRcolIxZ+rDmQpl&bH&rM;tlS@o0HXatzpH8i;UH#cu+YAwqKRCZPwdu zH(E~-Ivz-bWaOnvXo{wnS*j22v)BDsvE5#kih0Bn&f7X1z1c6;I@wfKh+O{E36br6 zRrzRo*3xcduU$n$MZNt}A?z#UXE;Tf@;KM%yK?_^kQbkiVyL)l>O?hpyqw8U;@=}H zJNSnjYF@I*!^MWBok-z6kOY|xuhnQa36_k{Y_$kD(IJ71VYQrP6G41h-(%=#8R|(w z-oTP^cx0^ebgiuJrs4YCjArx#w9Z8D6_jD_#tCM-g*%x_KI*fruPQuTT!_oWw+7%>Gf>t>7G$`1H?W`AoSnf#LS!S=~Mx zpRROz*l+8nb>19xbZmeu&W?w(a`+f&E}2rX<^qSW+;<^0eM?4pZS56^L)^SI+$mL+ zzw06jsD9+<{RW4FE}#pg^{SEiY^B8Paf+-Y6=BoS+ObjhiX6HYJYq~+rPf=;C}iSU zWWP{Dkb-hJl?kN!uFOHnmzXGi8(ro#Yzoe`f)vbfjwggVf}arT-UKPCcRC$!Ra8dB z9#9!=1kTSZHx%%Cb(~JJ+qX8bwq3S`;D{C<^1} zA(o}}aQhn?ewt&T`??UBA2Ws$6C25v_zn*be|Ilr zr!ly{9&!MIqmeiVmU6%PeZgIPC+RfQ?>-VywUV!F=Sm*pFinW)rf+DBkj1K3BSKPL zzw0Vij)5UBEjdc<5~R*0g=uapbApsGSJ59ZfMfVk&=upx2eZIujyyT3DIb`2SS>_S zEf=qp%KQEegki7I2P{v>n?5&F)8T#6GfN#rZm-UfU=hl1=&C*skQKY?lWB`MME=74 zl4w73B-9DAsg%`ZWt;rNbPL(Ila zQc@t)ZYmnu7(Bq$5>!V8bUl}b1*2XoGpZPda1kmULKv(^{`;8r{DA#C!BI>;1CV_G z0ePOS-2jL{V!7+g+MdlG|Bnhxx*B?gqjzQZx&BJNv1c+5k^Aeza13x=t6{Vr$u zwS{X)sqFyfaS@57&wE4mLo{%_XQzOC6<2h8(a zm=Qq{UPc(udnnw=4T-Pw!2-#If2eIRAsYDNFKEYJ)%mn(L**%IkSj5zY2y(b6HJMX z$#0p%rN^aMsZ->#2gwF%96_vA!be3Kfk9vGZEL1+MC}nb^=Pz2@)^K`N2;cAUGSp_ zJrLOO9l8++se&bBbugo#S-&(0*+y2O?K6_rC`VmzQ|RF?i+Q#|)E7e&?gr9B7r-cE zGW1juCVO;c<=-4?RFGqwF#X7t???)QjtKh_WrV_~jtcCuTqvcp*@VK(dh&$9fg+fs zL&xKnI@!tbCLH;vsXUnk$1ykYvRYjpS1Jm8LmLP^4ez5hnm7NQg#Z1HL^ObPEDZ5W zLp|}=ueJc#G@kE%H5}-8Q1f`Nr+&?y&fzoKAFbaTX}z57d+IGYUCu@O#pOzKq-lbi zr-BskOi0J4o~Ti{sIL^hIYllh4Re&i7N3GiM0edxs@fnuC>FU;kz0_OM}LCn;@NykEm6u8!RNScqj6pSSdpzFS2Z)7~XGefz?r& z41>xiy{u>v0hzIM?)A5fCOw-MdTuO_H_*-je#1e{eJtIH5Z;-%vqOn0K4L;pmhAI7EF~al3<#biXtgJ6{U+BqCZ9}WKZ$8>^kA>V57x$y)QUL z8~;U3lT%sjn(7vNKV2R4&xve2g)T+=&uk0KjHE^YV&DKM{#7}z{GV@+adB~ll_0GCp~nBtR7qhiX5c*EFWaKW)0_t4MhjfV zozg>N)9RQFmVf}2(a9=>TY8i$sWikzedD z|D%G`>j|KB>V0-}8cfCjSOZ+Q;vyiy@;l1$wo_J4oDIw|pUJ6~o$xeS9(w{}dZ#nk zWR4fHfkn*U#YtG~2#C|Z0W6-})%t)5MkEXYl^xn-+bBC*v_ly5kA7;qFW7Y&@D#&m z?`xb60=qUqE*b*vsu-lr>ygMxX!6a}oFg%~3*ElnJ3zq54&b~3456QVmnM<$IUNQg zQ2&g7J^~zC{2#B*O4K#iB{)FH<20B5JK%9W939DWXu#rJ6Qkhn0XW`(fW$~wDI{q( z3?M9EvO_x<@$v9zQ?cZ5IvVBj`?`?Ad(y}&rDtSFuRhUb`Jg!0qmo9+!X*fO|Na}W zx7Rq^ZFJmayKj;daevt{`U0{=$;~eVUHlu$4&sHFQ$I&aJkG;Y`-b` zixa%Zfs2O;f6T7kYNOR|g~?$CNQUMq<*Gz;pXWx_TkVfL1FiEU9xgh*g~#D*HT#I! 
z3B7BXXCEMZjyA<6&}yYu@3CBD1MQuiiSWi@Vq)g_9sU&eR>=N(%dC(LxK8%84iU~h zRU#KeZ(o|N+vm=K)cNStuOcgO=1NMa@dZ5YXP{V{XsocRn-eyi+X2C}g6YEeZ0u_0 zN{nmr{uG@~^=kgN=b=!Q3Yl9vzK-kJ5$k2Y8TpkOpfv^iBqFjlRK{tLzTV=?6Hi_} zKOfP~jJdl5Olc1$`;8LBnTeA%7;QqVY%6R=tpIkCsDu}P@=BFfqvZnS z0lcoxG!P)eMBoZLe+z)w5`}1DjWCq+GR8m-(C8BiM?Z*;bcP8kIInG z%vduI+`S`!ww}U;^V3}S>$OaNSzjCh54Zb8Rd5WzBq_<^D2wWgT<c;~Fd zD(Gi>oQG0IC8gfiw*PfEPFFkHZz`txYd;LxM3EVSrTghH6L6X?yRr=R_vyHBCUyp< z>HFcrV3bA`?V9Xx>TiUt0^3gj3pI0Hk{i)-*qsh=`9xo!cSr6@fv;;a08PinX2iZMqgbz*k;g zo;yqVH7FWiK0s!O#!tB{`CH8v+^xOCZmF~V3UH=^EAcx8s~MMa z@19$=9)bcs+b-uoPN;g*rIX(VBDjKe@IlX?UlB`%ZA`gjn{C%qUyOJ5`KDn|+geVF zLfA_%1TI}5gVq1DL$GR5y=0W_h3C;QNC5l#ZHk_pK*+X?|MGI2_~Mb`72ZQjU(R3tcY-t?}1zVO4e#V{C}m-j-zAt0%S zN}xSs!4&s%JOTo6jB|DfySu(8QWL^W<}!44RhzojFbmaxSWYSXR#yMLaGKsKds0#7 z95~koS5}dpw1C$aw^+G;4~8~K7{009ZFMloYkRYBNi33#TbK%thmC?HKnY>3xtZos zI8rhuI0>2uZcG0dKmPg3&H@BwbCG$L@x0d&3H4AAf{A@_y;p#Ud3_A-B)^oIRHhq~ z%>|Mh!x>;3uyATM#1+6(MWjzhZUfsSQX|tSm^{ICK$3j7$rD^Kf=e$3>1Y4#Kp67R zva2{CqGF$#*nizvbp-GCP*K~-Y(7vXIm#;gq#8X~!T6mL=++Q@j-u6TG*( zp*sXzjN>!&uNBn}Qi&9;_8Z(s3ku=!cx-NeINzwWFHvx3Qcv2IA|ZY0s|*ZE-7M{5 z=h6IxF<|FQLqzU&siwNB=KpLY)K3xPipSqre9&O?$Sa2rQBb_XzZ|?m?bQ+z5q+76 zb+nP4wD)_nw)?I)inzW!rM>yG{l?`ia5-Q7&tex_moL$%DAhRY+fBeDzIhNBto4fN zMID?MeXY3Lc$$k9Zn{lUtbnXy)fCD zaD((N8;Azhy-`2Enf;(#Yx;RvDE`g0Ka>JkZRDK@gelwe{(r%q@w~4ww z-6tu~eYyYjc3>Ca8Uf8YOmxv1kH`QYXA<~&+YPv5eP14&*Jo;T{%p{6$_DZhj*YD~ zd%Nq|**PUqZeydCvAsS(WP>??6ePFAs(sU94yRB8w?GP)nt=)UHyH;8?H<0@-9({#J<+vvPcCZMxj zD>%5qNhvMlE`-yWFk^)0rWmARwcbQ1Ln2hR3v6k>BD2jeYGUQ!h>u^m4Y zh-9!s!QDBgO7I+6T*PJbI0N@7yTrjr11d9B6O+>x+b>Cnzl-3QC%K%Oii+Cw-kw9n zLSgRD$y!|BT8xa18XNAN{Z3aZk;}e!*NnED|7Met7suL0-p31}<7e>O>iz{>sInT& zUPlQCs^uZ=>fCyLo9gZdha=%54J8OTo(dnMUH;OW%E&u2U_>O?D<&!Vh^Si zXi`+t0{L@kgMGO3{>&2q)r@_GN{xz$0WH>(eQ||B!b{dy6oe~WYZUlpxH||uX}?~k zJFR>4y`KXNE~%F0K!1m{Vilcfe1`b_N&vlh+NkWGy}A2Y9={l!+ifPsi3&q?UvJN zH3s)!;N2JYE^g230>&Olq_~$vP;CDXaLx{e{!Xt88VttwD!}PWnuVs8=Hq~mg2`Tz{*X0m5D`_E(L96R;Wi1L7A1IO8 zxTf>)cj7!6ZRDo~g|1_9OxeH7W3}B;xf~k~#(PD>3&6C9<=BF>JorwLbC#UJX=pSe zPr!A2E05#7q{-eJs1m#-m(88>_D+4Y5(oH!bpZhP8{K8K$rDjy<;MW>r&^lsrjf0Q zd?lk@%rrf(xpZcJN!{T7zyN`_v&QlyDhRbgqHN3Z@~kU{Y;L9B(Zcsx*eyqu3W>)H zC8+asj6J+u6rt$MHlwrF09|K_Lr)(-Nhmvw58s_`OSpgk4#d)Ja^9Xb&bjUXHH2sL zdF+FEX>51|0|5!PxJ&MP3qO^)q0Jd3zSdipN>3_bM0BwUE0;(Vs_pf zUwB>Swx>#7_DJ#hk0EI>)Hao__44Y$=PYMMc=tu#8XEdE$*H z>LbBxDerFU`J*?opl1YJKv~f-Exu@}Sh^&J$}D%2CSXU2vm6y^wMbEG7Cc3#k$AB< z+#|;NXG94Msl)$A7E}$YR*vKFPU=+V|6%GKqdE=uuO*L7QYceLA|7?4*ZBO=O z+n%h+m~7j&d+&LkcklhIK2&S-wGNaqMhwhk=QrCwo5-5F>D@D(?$^WUaM* z68S}^#cnke%jf-k{yp|?3x+6O4-OH=PkN&rhElc+ThI3Ov3v;|@%Z5x&$cH2X`TO7 zP{?(2vj?bVe59IXnJ@f4!&1|QtGZpQyZEW?rDGfMka}XF0f; z=nr1^`B$zjI117hKmRrS<{xN)J_kMjq>rMO5YJ zx*v}G+4asV^DGmbfl*|%sudYGPZ`aNIG3Dug>q(peYUR#OVgd+g5b!}8MGT<|MtyC zdbL>4TsvkW;&L|_z1;i7e>>oPv1;hD+wwm>rB82)eT43WP7yViAe?cVC1Ud8bhq~7 z6LCA->@m)=S#3|P+t)TNNGw%CA>i}++ql4PGyCJJX=7T2ndFy_A;Bfa=rK2m z(RsOO#QW0(45_6UI%KyXs#Y4BzwVmnhIdWU=VYtZ&bNl=J4$U|@A0YphRE#qVM_82 z#V1}%-_Il?ylDvRQcIKOET&#}e>sTJ@TVkieO?7YqovJ=w88~b|2V*ba+`@Z_0b0{joW%vNK2vGpolv-44XiDzH;JW!vuc zSAiVp(JGgt8M>}GMT;>{|qn*OyD3)L|{=iPc#zgZv z?c}!S2b0MeY*teLZ#!@-2DTdBx+P%Blj6^|>M=;bvF3o#;xOm+dgJrk$@2$rI;bkt zXMUC!?~!<_MNwX+!9mgnYT_vJ`R;k%Yjp(meN*!K`J@Vu%P{D&85!|ocxQBIs6?gB zMRA9Gi*m(gy6auE5IZX?B0QpmS9VI^i)lPX=;Ks7Pb|T`)%xAa2K$fMLw!}%`<}qt z*H1;p(ifC7MfGR5a=uR(T8$FPCyTYcT{qWYdLcv>hI|%`Xo?$+L<; zf96lHUR_h-^bL9ET$N6-XVZDVsOYBUHo9c8RYI{U=g)(*Im;qIqcy(p0e}lDH46c} z%iV=I;U}CgKI3+1aI7UNO$^pTzMcKiI<6L48dB;F7BKT@NP z!P4}Db9{^wpcGZQ-p#QyeBYxu)?`97HQeEUIrjY@QCoE(zyE;jQ1$HoT?X(XWDk{L}zt)V90dO 
zA7pz*hmd=aNfpw3r!9A#$a;8ve^{+BH8!5cM<0rGQu{{zt*(B85ISC(dTu=PV0HHR zD+Q_VC#2!7@4mN(z<=zyc&^yCC)#>th~7l+!z|lpzS_6z!*#=4q&w<||8TVp)X5C}XKw!u5P_ z&H!3WGxO0rm;|p^A8b5Wme)RRp*Djv^fmWM1Ps9lV0R%Jz*=BEEb*G{D2T*N-R*%sDQ{~fAmf~EGpZS7@1x1JyhXMmv zO{FHta`?;~JcENnI`&0zAeZFbh@k9_W+Ow|?v7?sD97^$EO{Zm9UCTtf<{cK7RR*S znZK@jJ1r8$=V(@NLzKU;m+5$2m9=w3s6i8q`4H8^)-nIaDj()~_mm zYMrni>QFQ$$wti#GCN}t8AF#q$0DzW+c^$Cp;xH=6%huNsT5|!JFm4y9 z_ZlBwEI42&B}T$iu)=nbrEsnFDV#t@#e zKJ>)x;hBoU(FQEmx!>gWhA8l&I@;IPp7v&P8SVeFUWq2;GUS}B5JiZeT6~f7SYGc) zMMB3|`W%H)6X!bLq_+VqOli=%Ipbf3P+xmH9s=B*k0-Byox|6Kpr1$F!{!uS@t)F~ zqMJfR5|^)|dLJX>KjalGUIlaM^!1kE9^=R+E$%H}JV1mS!FstBKr64raiagzG5K^? zK-j$Xa?URSrXm;0Lwfh%oD5dwQq-FmJHPgR=ywmt9JNz!n23T@L_dPI_9f0$X0{~G z8B)0RNGD~uoI`)>$IGn_uhHrj0AI~;;q4}k) zj7x6`2KjWk1K}}J31#TSk&XTSmSr5QP-UdnvOPXN4YhEeUh~@=XDw4{PDTJQ&MGah zQczSp(`?L?GzdXI2CJc}lPw*C4S~gnBa|{w8jsr+;c6ssIbW9*5I9S+WwKtLlul?u zJW#7egTF$B8L~KCEyXDI)389J(`^2+2E1nS$GIxj(6bVw3EeW__f@X0) zMtw!H)E!sCFW?LsbwNnT=M}#!fQ*t|JXtr$+GklstogeuyB7cOeWk~LU%{X7{L4hZ z=;Awufi!jYA5zL!jOmid9k1E~{n z$I|@3RWVki!CvUWg4;6QdCHhEfJ`-!{d1p`#9Y z$ZQf%IN$ZI1(tfNkU6_Hy^-+PAh!ySe)mX9`1`;r3;Lr}3d>_{2Su%NaL3jNXUMMN zUqi5`*v=$kx?u3%ALg@$o3H0R{j2{8@9S5>SMF92jB$f)Qh59z1XKKq4`k{;5lW)q z8}>^Lr(#0mlCxX%aVZL6W+>z!;U$}Lzb?``vPBPDfCAb4^WivqZz&DvSg7 zu8;~XGjjQF83jTK{E%auvS=&fWQkH}rj2F;h<-^VLIm+IpK3#sfPR6su~FQY(h}l$ zuIP%x&Vl^53@iA}nMWnCqC3 zN<{maOVZbU%!FdWyTzj`E9wkFsSG8PU>IrU&&gN>bRa#43i1c_EdwIC%9QElzbp@L z-He%JB?S`Stv(Q_woXJNn%~r0-NL%#Mwn1ez>K@qt@h_Ip1YI$NR-Vi3Nz@Z6aDk7 zR)1S-ePTTC`tX?^8M-Hps;X#6e8IO6{G(B<3it$n)Dci%o}NPEa(Nx1?w#p_<+Pkm zOAiNl0Ro^Q!}P1+NNGj_rE(CxAIs%PMsMhG|6mVvjl_96bLIPdnGz-6_<~Bvbp$38 zq-SAPW<`WBp%)b&?0%tiDU)4UG)jPw=*YcVR63@DkRXIf4Xz%o9kZ~s)XlZy5+SX& zyo!W3W*Ra!_EyN)Hnbm|b(>M1_hRb15DSTDyUUrTAGatJl)E_X{k;Q?yNlARU7cE{ z&xTWfE-u)MSt-inz>4iB(@gJsc{gRNB|W)|v_J9N0k7bC&eFcc*Rtsx2`5sm8FDJJ zn$hBqztB!qD8I-iao=M)FpP|TCDl<^?*-xA0biF69-)SvpVoYOa`Re8_3Y4CX*v|Z>AC}CXc=FmQlI8Wdjm`Ts zxOPpO7X+?Op15;GCw`pYVsSDLwM&upngIrJ@uh_htuO<;w0d)g$=<$Is3=k?A0PRx zy01s8?_1d^cUYBFVkpRzO2uWpR5tZ<>-1?Y(3^ceW!f`FPDZEta+nKCtp{HP|M#JV zHi6>Z@Zb9&&2;Gc2`M4la(5E>A0Y>10_hgNW#+qdF8JIHkQ7T?eQ3OEe*Xua^Y8PY zC@{VGTjkOqQFi1ME3B&Pn#Uf}h~Kibi;pf-^BN}NE`+1s!Fcu;RHd@_#Aa=uC-M96*Y>qB@09r_pMhbL~=Ha-Z5iL<{$+AR)PN6rQC ziuNcCL8|QsnVz~2RXRNuOW)3WokpVj7s?lS*Z{8oniGa9pPl$^wPXU-u-Z!YnwX|4 zfG-^fVoZRbLV(S->+Q3aWH&Fc)Pfe*YCfc;5PfTVLf9{_eU31Kzt8(k92V3q1#sh+ znl9XUrNa3Y%O;9L>AI-W6c)I>c5BMKp|yvg>jIZd30Q@>LQg|XF&KE&4{@qY0d!Y? 
zm03@qNZ&yuk(;3}?VIFJ6s6Pn^#=PHz(hnIHzCpTcekAvbP-fnTBNa$+1EL*pqW)zfM8prnGRIo}w)nI^9}_Js?t&W9(eLDAZuS1I?bI=P=c8mf`*t?^3dMH8}}N zYS)Ms=Fv--LOYr$4n@M^I0Wfrc#fX>isC?#!Ed#jtz0#k<@vqbD$J@sktkDy9}(;Q zQI101RymVfqhNVEj7n3m4~EutkN-q#FAn~poxX@A$V<CevcPjdgWC_+*^|~M;)%~T(K(t^*c$--j|!X!2KsH$6xhH0__yNFI!s2DPCDx z2jsVdmM?}(z+GT@Fy4-hsa+y8*Xmquqtj}$(Rxc_oG3|&3@5v#;Dyt0^vMgA zL*cUpd@AGb^{=)+5tx1v{U}Y8l`q5#>M<3UqB`x41S^dz3yYGuTrz&Jlx98c7Td%W z_sd*feK{7R&gy=d{S`B7WcdGNlleMZh)hVaM2adRugqo|FR?IR9!DnWf% zEPU^IxlqDHcZ#%G@FUh^&>jM<#T`VGZiLtQor>tRx?Kg>+M_*@p4>4sNQ6h`r$qOu(T3~+Z3ke?hSsV!p5 z6-tC(@xH9q6Z=^m6X-ENTq*$0>!6@nQgfOr>Anh&pU?*jAIjjGnp(u|6K1+ zjx;DHsrlNZb?W4^9Y*PKcl7eM{O@M`y)4Zjr43_lZY~@ORCMCR->-Z>IQFNdduAo> zQr@oy;U%!p!q5`|Hxji{5k8k!v(dr0fE!Wg2OsLzPc$-rk>CxS$><`vi7e~oMv81q z;yUpO-Bl*QoAd&M@)*n5oy+Zfa=OxWk(8_lO;!X|k7B1?8;~MTuAny8u2IhdgZSw1 z9#EX^!R$S84*cl!ZrW|4G-F#N0vOdN(>IZmI>EuVhtvI8FwEc93t6BIBA-V=O^f9v z;m|U6Y{Ko%a%3wO<~s3qXarB99xD7B(FByz`$i^u+=h zCh@f0$$>OTx$DovG@Ic0(^O;Rd^RhsNu2`Sj;y)R9ONQihBf`pavrVB3ps+lz^2FN z?M^L0&-+f_g?(jfX+aj-mfNGrd_~MvQUgdzY%P7-P3|*Tk8EfaJLrSMI0{0JE`2-^ zS=dwDyvS(ttWzws;9KHBOpO1kCr7{{f5H1J7bk@#Z zBWjjjs#L`De3(C~Rz*im^W?RD(A{P%MqB#!QlDobu7D`Gha5qJixY?WJdDRl|6J9r z&sePdEv#%Tf^vO7gA&}RsxJpb6(RVSnis@w!aAq-9i1@z*X5@|Jfz3mB9^5qMiy{D z7NaI>b-NcRvHZ5|A#l1VwmB=}=Lxoy5R~*h&oW&%DhX8D>=Oo*$!}YfW)|WAhI@0@i#x8tTCouQhsqRcGV%;0? z{xBInhy;4Y{LC_xvd`Ffc5st`jUbGjQq(>i1z+m$jg}$1hpX?EkOu7z8#26)XCyE{ zyk4zCW#u=AOol|clU}E>d$W@dlyONCjR!DJ_%SKNOlP;X>V6~92Ekn{>PB8j^j7I9 znxRUgZ)d4moyFr$bAS*#wIEgK^ERrasl1Z!RQ_xv+DUq?m@ z;2oHD*sYsIxdN{Xk69fNV{GBH4c>LejKz_^znrR+t1$xMQWtzhf;z2*Dzm3F*%#eF zXcTZppPRdd3*XI8U^Qf8_{vaDE%eZqjU@XY=!t z;h-+UFgx2Rm=Tdliq_&C3Q*RAqCS?ba)HBGA=t2+hKvq8U}6t*a1&EhUzdF(`S zC_2`}b0WTRq*Q-{k`|^nKA!(hO78l3hshorfi_nbIOvp&6F`P_JYd3g* zb{c5v%Zh`M&e3GOqN}MnU8uz|a2o$rHgdH#FV8T`R1!ydEok`nN}H{oltZ0Gp9}*| zlC{Il_2sYVftULG{M0YrfDw<$zSarJ=PEea$M=E=m`EGeu~Emu915Eu?Xhnn7bDu= zgF|?L+XTe3S|q}@{vl(bjLY-61>T0ypE+)MijH;8xEm$$L}*?nNhVd8H`NK|NFL5s zJjp4=sm1BxT`!gryf29@%+S2(u-9=V&lR#t@|`tEsbGge5KbSH^OC)FcX~sKiguo@oHofyfxrGSI|iXMnA`L+ z`210c$>JX#HvThB9{lMCw(cdYq=D2Vl^5Bl&1d(4t1Pt&VKkCfw(x z6Np+#2lz^d#Xl&@GT*R#1T*DTMJ_P)HGbb*tfZ&vsY}G`7$^v8YrhU4on(9O^v|YmpYd>XJ#^c%( zv+!oVU37NOS5I8GuycwJQ&Ur^VGqA|yp$lPqT;2`lBc4gX))N!Y&zzPA^F>sw*Q2C zw=Z?f?V9))BGAAMfiIBV{3?LuRuNjf48qc zDX}`aok%gTB_q`;D1tV;&$&URLfNVbewR7);;F(-z+5}qulo}YB2GFL6JcA;UZ~lk z5j75Erw-qw17u=&oR?$rH4=MSbqIcJZikak2-rX$>Sdwp8S8r{a1B2@B*OX%wxzMz zBL9Rqhfo}jNUGCF;^Bij0UNm7M!ki%lidc(^p#@q^>2 zgJV!s{S7tC*0JrQU-Ks+qcDA)7}bz~l>%WQkn67-9&xfujRJM zdF9vlt;poWAk9wsRRE39ttvDa&HABGPxEIluk(Y(th=(C>Pc<6lQ*M6`8Fq4Hx#)Q zwwvh#GZ``ar9X|aIj2oazM+Pr5624B^6ZS%rLe?79s$oqhehumprzQp+>DNOP!EU5}Fq@Jb zVXDVfz(fg*y#*d5K90n+9A=6v=DL55cD_VM?f{c&rA7hzpsrSEVud%&K1ggoP3&wR z#*Rl)E_)7{c7tx*V#i~CKqud5shZFvt^F^MwFr$uD01Sy=14)iPLp3Hay7^$|Jtfa zDaS^=t>}E4QveSl&C^2B5pFhsOWt6^OuT|Ra>*=vjCIcZ-3P=)~&fmnc_W7fz2qM%*d-bMhPzCst`sz68{soz@GT; zuJ|QLzxb5I0*BSU*lTS9?mH@O97A9{AA*Qqz-;Iijd~>jq3n+#`p5xA+lYTHrw9I% zO{NBs(Xv_p*`9FSD`)A7 zF%%NgII4SzsA6T45Y$T!nZN~y5g8);tV4q!^7X|SA$3O zR0@P?h(2z!y}W&*EkLkW6lM%E#S;LkmKqIZJ#@T0T Rovo9N2t*%y?Y zl4?-CFsF=3>BvXTwMGWhk|Gw!#Y=y;oIl-odnfc+&uQb1TJA9eSSQE@U}>x6W^yj8 zfte?1T&PIgnf4TXMhrEHfm0DV;M@HQd_hH)yKYsjpHhoEbmzInPrfKHRb-8aR;4xB(dv!z~QRN(nqxbA=P|rS$U#KqC?W$A6A+c zG#-y(0h}-(bfh0a*%S=TmqP5z<;7Szi}mB&`Wz}570#~mhe@Fv9)q+p7du+$fhS*N#T6u@wsl?M&6|?4aqJQPu2J z``ncSCSCai_3b~hbB|X$sZj64uJL;8#v}jta99!Yu=Fn@q?PE?+J(=cTqe0TmGG)@ zw&OTgq#(Kbb?#oaY%Gstad+AOI83k<;P4UcK)KcJaQpYu{P&A?AlQjSH~mC>P!MEz z(s#WH!MJcROw$%h{HSo|ahK>&+*`GuVWpCvgWtw8btrl+k2Qqb8*YfuFB}LX7>M3F 
zRrZvlT88I**CT)K4RqTdJMt;(TUlV^OW{$iP&1JsO)g9Geo}HV*Vg8K`z{Na=l0_! z1-^iu&LioLjOx>1UT&xm!Yjfk1|>`5mx9YeM|RBc>saW$@$Gu5gX}6uQzI-3c2u2Y zBvw@g2NgVt7DLr6ynw1cfui2L?+b?oL4}t=-#jatEY7b+QD0u2Ht6t6#e1W@(Keg3 zJHP#F%GV491p&wu4n;;qp_vv(GoK6Z+bQv-J7SUV|K2W~?r1qI2cDsH0dT2(TSyCoX4$9}sWWh5FTy=dO~ zwPK^RmfeqWK$ej}I`%98WD2j{^ydqW>9&QYV5EvcTqH*B@n*!xKj~AEjHT_NO%TYD z*Qjadk$f{UKmc*J&pGc4UVx#(Uv8*O%qjDd0Gl5(T%s|EDWZY+G#bSq1w+J)45`2x zF5KBfL2T+P5mse)KR98n9AT3+JGeHG6jlaPB=j`uQ0VbZIcG-69^|!nacHx#j!0A> zyvH}?8)GOe>>1{6vtrD8e_c|lK!J=Zr{ctfZF*RK#=Uwzmn8JOCagm7lyWW!Y3Y37 zUrJ%YWSLA)g-YS!;Z(zdUwPEYYd2%{JSeCd=#78{@3~>y;dG(&%l#RxlFL8nrX4eB zH*&INb%#s0-zVtLrB?nDibF|XV|ZRqS6gOC<0*GA)ieLMdKH5_?5+7RMS7k9h z{NZ|LpONP9acC=}!*bs7%H>~2^0q=sm=CKW=I)xAIm3P)(mHF}D`rSxifUJa3(GGF za6?QX1rLxY4wNMkQfd*ukxUE6=>CKcmVLT#dfe5(oFRlL=KuCx2aEk6F^`ynhj1U% zljtDa4sSe_{)u8d30iTl054Y{Nzgu=Dtr;CZ&NC{(M|p8J{R6)X|A|Q8hwxspE|cD za0iVTg_fE+EaO-J?Ni+8S{sKxMc9MC8{OK!t+|NkW&i?1tfKfKIaHw3O~WbX+tk$5 zuEexM__?!C!@v3He?I&Zus_)Zw&3nprF=KzHoxqU(IvmjEUjmMJgv8Fd2{iYf^Iy& zVm(Yuq7;aW_y2Yh(>Le4uWXT|Q%ukF<(#;rOr8g&zvN29dQs&>C`J%P)RBs3T;->z znmrp{I7?LqDHTL*-s9fnce+0=dC(wVp(!g)-0-28evH{UyMK%~75Tqq)GY%6&vw@V zWPHtOb_{|~7sf6Q@N!tH#qHOVrHKf_#MON;-kZlUCDUBx?8wqdZVZads3#t*vnkY+ z>Ejnn<%{}uQGrx&_WD~xs1SDuhC7`!hZbA8bgcYPFnbOE2TnN`33dkSj8q`Zy-`3M z-|ocz|8(YR&`Lm%i-f=+tSB)NvyLFsW3kO{m?C+$LL=eA#0Sl5-~&7Y9|+!|$fRMy ze=g-G<-Yq)Eoj>oOycynY9b?~nR~Z-+zOw&!&x!0d3ry?S<%UHd`X^4kc|_!tcJl> zF5~N_Yhxa1{(Gy`QHVZl-@kNpc#^IMEoTkrl_@ zAdY$_ym0=IpC<;*9++L_G@_UxlWc7RJ*aA-uXCwU!o_NdFQen*^GBbfXY&mPb%Dp3 zp3P8DFWk?eL|~J+_;2C7xB&qYjID(G|F05%J!Qy}y6mhhkm%6Zm<+t2A~7=)ll|^c z97rVo>uflz#cd86b zc+^2bvdpnub}WyXN5wyM_SBc!a(riP}9KFUc0BvFG9sbG~l>NNzjJ!$S0q9 zKhy3z`7G4VWW?XeSDl~)EC&;95; zgq!6ppTlAdNGF}qXK@wO*#-lStJ9WI3&AV7MXr>(zY>n*ScvVh<__LrwZ7++A--*0 z<84zOO*UWl4Nv}+Bo|I(qYJSa+Em`}Yh|*;hvu&IKMZ+{asGDf9EI@CA2q7vQH1ZR5?RGyM@yu}-Q6X6ws~7cYEy#^_CNU* z@woWxy^L(Rk%x`w&x1)}I5k|_?l3L-n>#SaB9Qr~GuHIvlV}Y*gh_B0I!9q&Z);EK zO+=z%$qSj13fUJ$8WHX(XZ%(LHG9e_mE45C>D$WNh^0Nf20x`;%C7AOhAU3P?kO%stqj!=x%y3L`hvf+xhaWs!K6RJ#n5(vW>X=KFtY?Arm>(E^D^KH1SP3?U$G)u z*(vabepPAQ9&R&#-Wx%6+FD40?WzgOkr@^+kL?wJ5EhoF-a@%uwb&XG$o>7l8svW- zwLmE4V;L5o+vx|u9U&K`!DObP2`^$zq}&$O69-TjU*Je_f4spUUE_bxSBq|q#gi6qE@d%~NRG?wo-=}TOmRTHV-+i0uJ zbZag}#eA1$%#ArkE6~MleCD$v|B8bykjn8K+1a^iLYxW=mQ@s2}rI(5FQt&K1x^;cwM(5G``b31!HUK&V7eR&c-LQ|@O-g-1DQ*ALOrF#^UiB#H6 zLb+r#HnCo6j++27UaS_2@`7;D3~*-j4H(M_G=(x1$EzB@kRJJ>%6t-3p!0e)uZ z^BoKzJQ-TPI{`GpOm36J#k))O2kbVh>WYdy`kMk3#5s7=dsrWrTfHZMQ&k)S260l$ zSrh6nj1*isn9Xv#$wk&0AQ!V151~gJa^tjImkHU*9Npb$Txs6&!{`M}3nA zML^e~vp9K!GW|dXacEjAuIIZhCjez6^!_Z#tH32w)hQZhI<+|$d0{?$K$>~`_f!GM zntPg!?w9$0x3>Smmji<|=V-9IFXZ@3zG7HHE=SYB!C{MOAz0ym7~(%&ce5mziAuRLF0KPs}(D z-|+w<^OK6__Ih5O4tr%uw7^laVF8)W6e22ULXVT*b3EfTBclP733LPj8F`oU%)XXQ zSJbK!$mJd=m;f_{U)uK({z_^G_?&K=QKd>Sw?vz1(LAp2tHs$qLV`}3nrTwv$E}kg zO`iX3(RT}>D0us6(aPS~kiu$@+^5*(Y7>ucrj^T$QBLHoW)m5iT$Z1`%?{M7oeCCI zh?BT>c4Au4xLuBSWyHA~y|l5KB*hujn=dbW0t@^yk0j7Az-i%s;h%@Hn3A*oa0B{^ zak_H+9=C7iGZJJjQqnha;@oeTcCk&iD+3n-4{MIq1!RRPuzeY!v=!fy8MIoxmg=67 zj-`R?g0pG}q-Cd!kC_YE_e*KLSTZP+8E?MIbFWeDrDa^0k!r2;q zhTFdY?AVJ@Gud{bN+*BRm$mLlqPkh*|80Z)*B7D22~G#u0vMkhuumm+PQ~7y?=ISJ zK#J&z0*VpJZB|*0N@Gcm*SnI9#>VCImA>yMdNDr>ZpW}{!e!pxz70i~7)bU^CA^HA zW;r!#$t!oRfFblk3OWW*`GT^!It-t#l_=%de=OPofL`Tet{L3c18!umVsna$8%ZZc zT5&eZj@CP>((i7Uo4~^g8U?pP_-Q5wmyhbz5u2exvo`)VsGRaWCCO5@LaF$p+rPB* z?0CMTf<``<>kZK`I3JNZ2Nf3BK*257Hj+q`Ok?gE~Q;{dlk5#yHw@n z*dE~<#(j$UQqL>Ho+fQdGR z&iJFfv!tki*ZAm!8IIAK3MZ^m6w{}%k#TPCV`U%RET!a&LbTqHByX~omwjb4P#`pt z7%~xIa;!j}Qi`s~A`VS|kml8F>~A?)>*<18?8F|msU4{}XX~}HJLjDN5_xrbH=x>C 
zacWhSPeL?By|8^doS@ONhhQlo;f$(Q#IRZ5SXW2o2D7?UP?fn&Yn-{0hz+7xMO`Bn z0qLHXDivQ}J9Gp&^S@=napEMB_);;NZQh(9Lug?)pNiPlH2gpHzyJKhfr_yj1Gq5> zb#9(~geXbs&&Axw-udm~A%t3GX2yk)CbGB2Q6KGc7YSAv;@K35b&ao8-b*tDb4@x~2@LCnd@5AMTba0oaxb z4VK@{1-qw9^_Uj5qG$^dw@lT2i=-CS77wFS((tT?;A8Qf&hF;k%*LgHYtv-LoLZhu>Ka?HLB6Z)5Fb`nR6k9yuP2vsq8GwwTmarDvum$N3~)er@{E zvy9LO&(AfRqsT|=#ivoSH58X0427-5Y8yApDH*~b%hlqJ!|VKvInE;Vc%6)fF!bbj zAf)y?$FJp;b#4~J*!XM|*(Q9fVk~|e-6ytNxEV4Vd?F%Pls-XAUQllLWVTeR&4u13 zd^>O3vwxL6+2ss&s6$xhj1b3{Ph~nU zuTNv!@uUOfMp@N#kjo9$qYs3wBlH6#HR3UZdm!_f)+i>rqyq^Wt&H$IyTr(MYG&^G zM{69w9|DvXy%xF17agd)<^LKXZK^c=xMO*#rzOJV~xl4aVlDCNE z8O)|I)p#I_eKW@Gy6=ltF;*$G;|A2;(+hWR#p zR;-CYA3u)jD2#o5K3+C}Bl_o;jSoKCYZkL{FJS6YVo6I)6JCkCQ0`i7>-(}|o;j^r zt$o+j_2E61G6l%6o(=EHm;l1&4$+$gm#~ZJeR1~t!^ev+;@5WXYb_w~IIjhRf_r^{ ztxKH#RGHmr;Ul? z*X`c`_nTsw%s)CpE_6kvGRfnal>FRWEh?jGy532qO_#MpWm z|A)`#6lYD2-{n+Gj%TS@?gHUpf(7lp!*!$Mt*R{xNEy%p7M}NN#*W1K7M858m+|if z3C|A~gN<5OPw>;fx}N8%)zRZLP05GZ9Ji2(xchhsu`p)J#ZjPf%7OPO5oDCn==@*) zx2qT+ZSoh4ESS|{wuCadt^nbVQS}Ql3 z07PWA6Ir|$GU;MfvAZ19Jlk^YQd_nwokk;;)%1E-TuH@VO{~`Qr=6bXEpO4^hoVSx zHjpE;|HTXvH0tZ>9D&WXgdDJSwsw-IcaTDV0U8~jQK*u$%av$Jd@Oc6YOI-H;!opov zttN9}Q|s-=D-@o`oa5lt}jt~Ye~er0ufzx<9$=ivVst1m63?h9>1_I<9;K3i@quPKS8br_{@;<>); z9UdMQ?eCAw1HhM#EB9yQOT%EGG}a>^L?L5iZ{8@y`x|J~ueRIo0Hwg@5!ekL^XWiX z%BJoV)zMpO~-msctzebFS1=X+#fdmAt>0Sw03Mx=3noxx&dHGhFkBbulLhf z>I@*t%E3%Td!N{N0zKHQHS0Z6gj5sY1Aeh#z#!zfBb$qRlFni6v6rM4f{5z@_}NPv zF4EX6bv@4-PX=QSLcEWveYctzIvW8mAT;Ltx7xR*;%r>jAE-yBK9;F~{;%Errdfuy z5O}VN7Igor8O_xLH#oM(C;|mxO$I|ue`2!cX4dEFYBA)Qs6E)^X!`hl@V^u9|NGT~ z9pq3Tz)6Z0l7OX=9pWk5B`X)G8{p>xcQ6?6P)RyR0$OU304fw)GT7;SdN29oZDaO| zn)MF!qdgF}_}wKCw!G=@NQdL`eAMxfyu$8pc?EKDFRRLi-`+4To$q~5WPiUjsLH%4 z-6`8UW#`@%B~XIsbUHKCESbF<(a}8~&O6M2Wj|7ZrxU9v(PGyJO4rL;nlsSa<^{Gp zS#1yJ>1-!P>#f5?S0li#da0w$6~mE2h9>BEu;bH|}cBVH2rq}6ol%?}Me~G2Q zcEe_^Ekey0PUE-j>7qTbrOf>qCSnFqUL71A?oU_L6ux7{7xIKtKuJCXy>dA_^t8Lb zJe(778%t;=GvIO9Xx8ZS?SXhb??@9G_1^kq6tBt5%MU4_l|MIt)e=xczfeg<*WHS(w4LTo_ONp6qq<@P4NkHk?>$ z>-Bt8?Co*)g~@>&F0EhsSFG)^YV9Db*j#K0QJGbgk?e48AWTB^?zR0y)k;mSa53m$ zJj-&jrsq7fjooT>vB}q$fx{$5IW(%vXLSgTOi~-&xEmPhgVu2UV2}y&zS-R#G66R= z9agK=({;f>RGZV$e2!pA!7L}c^{T|cWtQi;!8bbX%O%g5A_aRf#HT9a3*pxAlo^Em z$Amo+Ymxttt+xt`tKGJB1Azp04+M9&U%bV z&o$RtXFps(UGY&h7^C;rTYsB39l%ak)%9e(xq%$dm$}>l_ni{0=cWkZ$Z^#6Qre~$jlnXXdyZq_Q zhlG^h=1)40?dvviO$}$3`y;J|_jLc?-d^YtRLQHjaKbM}aKln(;yAk$Lc_fND&n`tE9wTtXJ$bz12qvZA5k}C0IZ&p!zV85US2oVnVve6 ziRVFq*BB7WBe)KF^;9!=rK?6bq2icfCmUUv&2VVOMJ3H4}o*!Fyi4#I*=;bX9q-H zpDw%5FfpjD>T5cghjQC&h^hQ|Lk>fz!KOf@+_eDtZ#oOIuoN9mb2JO^o8gU z2YYmutox2+9YXLs?w`9P>Y?`xJH} zH`X_@0vUSs%FhCu?s3{O`QKa~Wjb2Hulqi)MZH${U6FAF%92iJ!c-|4ig{v+#sjN5 zCQpqvv$s7TiJLbEKGAMltkw?$|NkeeH6lI6yGBitW`9r5J%i^k9@uJnU?u^95g_=> z0+%60hWp*OI|c84u`#j@skw;oqL9O^w_V+2U!awir*e;1%Tr21j+aPcxqA*4ctuYs z*^1k8K3(a!=tVT*7>Fc3=PgT>VvQ5u!5@qukhLdMo-1icpRP9-_Jqc-AKm+<*36cz z`P)9IO!38bdJ7<-;B_vt}ibs68h9t*mjch}5kg%i)$ zAJz8jIcB25ahUYblrm%k4YjpB-@M;@KaA=!Xg12U_x??@?M$Hq)eh53=rDsuCNbm% z{vruz#*SGJ%3cNwjAaOPsb}%$Zi(7e*k(Gh@1_H2vy<62R9f#p_kif=aWscNqPPM+ z04C#;H^kh&@b}LlSK+$PewXFh32!>ps{4ri0qo?o6|n@hheHzNCcFRYb}K{S6ta7& zU?BnY|Knam3!N}eQ zS>DdS$g;NJ8K~XeHTj(mo#*c}=A(eCnQ?-4Xe;Z_ccxm$v`+svnMHTK-Oh-L$m1sN zcsI|L9@RXLU(ELh##(}Xl5yV?gjl;0ojw=%?HCll==X?+d3|=fM(5O2t=FiqUH{Vo zU?o4Z139i97bWUsrr#gMjIc=WKV&H@ecFK%iPy->OOYAT^2qO?`eXNg3F*V{+lm#M zZh&ylTIaO1qF$mrI9KOlYhpH2HJ^$-(D~ogF`-&w!rTcwSgoT|whr{8ML zF2FlvDXfWKdyg~Apv|q<-3{U!#q2c52>0_Knw`;5`1ur)D_r_qtV_3#_nyo$FB>ZqbnCgi+UoBt<`=9@mrSb_HZS* zR$Nu4yEL%F<4#Iac5^1bFrRHcm3E9_$((z@{OdQ-DyNh;O#FH$K+hEoO8QbCw!K2s9GDLjEdWvVJa?x!4{d 
zJV0#(eX@N#l=m5lF&Tzc0w)XtZGRL$>mLcv_*3I%V6|Hv8NM-%dU{e`Ex(s_ME+eu zl@ja7xXWKzF2(SXrjgg}!DvERQ_c5ME}xPrKCE zaT!l-)4D14D=k|)LHa#EBqIv}j{dE^GkurrNUc=jQ?gEBKf)P04g#2#1iO&TigyMs z;5eKKjpK*Rr%H^LQCBeG%oQfORd{r_6eAZ-2+vUpk>EEyl*2s5J^G ztpYb=>=s;crR4s4?Q4h0Ib8Rns8w9@{VpTAm8Whvsd;OKlQETCXS~8to4wDy&nM62 zzam!}q*@c2;ZfvNH~az-Q-YtSBT~w9f>diA#?-$#?XXCru?;KU`E$VRDpE+Ls%abO z^h?sldCm4i&&0{;OKiNmG1K*|33_Y;Qcx6AhTzZ%Z)+rFB}u2TZ^2qBPJU*3(oXnc zwFQzgPDtA^fwvm4%zRvTt{*H93Z!eOVUJbCa0sR9HBfhsfZK1-vg47BMv^_Pcqxi3t1fhth()VaZ5DEU*hf$ZFh6X;T0)5_IxgEK`x^R{Ce1XQBE+? zhZizysYqcnBQV0^FhJ6 z*~*{%0IjeUh7ubhLF(>VtB5*M8tVZAt{e(-Ork&O)r#RcG6I#u33J4$&)o8-r17{_Pj`t_z><~XsaQ}D63+2@Jf&`ISi`9+i* zA06ZVSmI&$s3_%SEav`UbUN+__2#JW4sWAXO8GB$nBdP4%f*;;diw}nlxE>0;dp)O zq5Pr(wFEoGr15NnU8sjJG?4#B?wquY3Z(VKq5o=4v}htrVALyVE}yjtkM77qDNdWK z<|~}tPMgG_iYGt43ydNv^ykMJJNmqu!=VunZkh4_V;pW*KqeR#hAX__NAaV_5~=00 zR18ff+ez$4Qlf7Ci?TZ`E*xf^vGng$4!@Z3-5&CX=u#1{c#G05;_kNSL69WX`My7# z?o3LmvR6{8O)nu|#%G;og$|}>f&s;Oy~%d<{|r)kL>cI{PIAi2HK-4WHB(F$+oa(8^Yv}@ZPx@^@X36&a!cWM0qLk^jSSJVfYE5b9^q9W1PCA5`7ddFd!_B;@tKhYPB z^er^=SFDll{geQx4%kI=M=SX|gE`p`S*|>D~ zG%R`_zQIBnG{nSfg;%j078afyxB8`^0;)hR#bZKm+rXY6OwjoY(Qd*@fLBT?mu(N> z1JGl|w?6?=e36>g1_C6aYIp`M0s))%3(k72o1>7o_kL}dITKm9BGA0W@RJ#C5Hh|1-t;*H4uIIWAY?at;V_S_*aZiEkTl?HHTj0Qp<#jkIm;?w{vvERElbLm^kfg%F=Oq z?P6}R#0Mz;w#QQV(PCm=*O-d0$f5rsRKqh9hws)O)>?d%U(v)O|g4lo^HFIAzh!pc_w?M56xA9 zx-#oS9D>F@7b-yHpAmk-&oF3vQRO)5ZnI#5TMot=_6TYsW&Re*s>aEwaBL}^i5cpU zB_1jqFkOz&;tU$KR&KHXc7K}403R*rui)nz4UOvQbvl;q?0jo;W<^$`Wg|$jnLI_M z30=;5+rnjF*5Ir8qTF@fdH4AP51shTEM-zK4pR@sZmaFOP9Z71Z6^>Q_4Yb$doSmx zW_KJ)Elxee8uw16gTZ#S!SJJHzJp)IHM#hbSqSgouVLdPzCL73R=V*tPV?Sd-IYcw z6DBISQ}O;eoA6FnLcBV`|$dZy_|J~Wm~tXUM|b$ z^&#!eb=FdTZ5-fS=llBUcb~7-1sA->15tPm=Kv8HaTR(RpxAnTcr$EN3qj)LDnqQ@_-ZIcCbr*$(htMbe`5zD zV=w`FUmw1mKCdGzW#1B6Lw;IH z(e@5tYSsoMW%0S3zS=#0ltJKkeB}7#8U72m2Ww92@^a2q>1p6I)(+ZAqxsWA*yP*w zpujB=i{8pY1)vNu{c$fr6rkiZ?mL+7ddYfw06s)_VE7vMj*|NyN7Ke;@LanzC9NV^ zl@KX5*m3T2+#VlzPe$o?&?g`YuEy&lmyn|*PbiesZwH5B&xp`aB3>t&HN%^4Hd*@Q zV%2M3%-JaM#9b`r$tVZEjKdraZIajVr-(waId+TiB>k#Lo?ESxKTjc^eD>@gf#0q& z{V9M|f@|p23Y+D1`^+KS1h|;&N3F*bS>O1Mw>5zb zvD4|gVArB(Ug$+i#z>;{l-e+N0|V=sSkb^hnN!m>u_W-WMUVp*U>h-hGVwIwFYk6o z4^Sk2`}D=?{smanF(H|AVER$Jlb6C?@@TXTo)JV4>SglX@fu>UrRKv;g~ibn-1rOJ znKSSOVmGBpn5p+WYwbt!V z)JBf=ueR6MPbf(iMd&yr3yZzNaxXeQxOBaJ3;)al&Jbl3$lwg`r)*KO+CMF2)G1N> zz%H^I&T08}J3=3j#LuMLncl{xO)!hSVk9LYNm6anvnQMO$L($n9&<5(H(_)Zm344l zpM>;bA6Iur{o@{qtwxj0UIe}&knCAT7KV+qWnfU!MJ3)-c!8vNyQBZaR$H&aT=}D| z5nvG|Tux+hWnFU_`hC5!j;FdgTl)Kg%CG11P_sQyL$VF96?z@w889UaQ3vlkel{4; zrtf_81pwRC1)Sg>RDtU$fSIFtf737UWI#IW|*O#XhPjMZLx^{h~#ojIL6KTY8s8a^oR$i1^ zsNQJrYw{~TYh|U+nO~PIhb3-ov_MtXVvrqUB{)jG^xkV94+a0nEYI}e}r=_4yyH!#+AzzbzsKIXTAIc=yJ@%Y=w=|9mX^Hrz&TEj9%p!1I*GyUt>fT zas!D@C-j4l7VQC>an?l2cxe;^!~F#?)6Ni5WRX0K!eJz)G=f9I(uSaD*J`cu)^YPK-5x}6|6qw$n$UW^S z%Zv8+0T3J0SJN!k^SF!KLr7KeMQsNTd_XH-v5Il$morwDW;tK>%NA=3;PbXRmF-0j z-(qnBCfK&`9FmB=t`+qq;X_I1*QK+J^Tad4u))d!do<$W@pV*|*T=eVsWm+S&pjym zA42(Zz4@pKKiz`yyBOoTzELK8c75B{L#jeod5@>lbs)JozpKTMTQUw&WJd;IJ);~^ znvSRPxagSK5~Bnyfn3%_sieDz(;CbsZs_P3LA7S+{k@)bHJK!iKbMkW^BjtY0#>K4 zib?lJQLrPY(5g0ywbUDFO<&KMeQxL*wkTSL{t3f#vfir%kq%4GkYX7_L*oU?)TX_t z4jIX&eTmd2#f_2{aG#?oWysL2BPHfoA9-%QMz^R))-JcAg|Gtb)72JRtJV7%Y33RN zf0~kqi6%~)w`O>%osGf?QT$P{roR1f*vA%lOI5QrKVkAK1@#If0W<T!OHD~xgo}M9TgY4vHv-B083l}&$($lV z(F6^^2;94-`^60VtTt-nb5r-*!);T^y-usqVXkO699tLf*S}+RcdKBoV{E(i%QexS z6HDbH3hxVk(H<;+Uc*XZECi2!Ba=N}frooc7^_?Ata4Bjir=!)8P%%7C6C7*hiWYhU;;1Whs~!Yk=+gWP^H(JX(scfs&6;CZ#Q!QNhZYE`q8|j7T93?*B2ZfR3k~U 
zfDb=dCZEX7V<$3DUT4#Fdux}GWC$C>Q0?#zxE<&$r|jOH{GIfw)OP6-XwzoCa3z(xIWVZ=d8Gj_Q7O-b4FS?IjDD6Y^S)PHsH7TQY95;wqi>eD-}J zFAe8dHPjbh@H8SfX>{2L)VxEM#Q9;QMTwh~pCcte`YsDgV_^C^yqtKrH<)6MU`{+@ zXM3@ROxFD0EP$GA?+4wkck%E~Ky;BnmfKHj8JbLk7$Xyt8YskJ@&!R7(_~=HtMpC6 z+HrWX#J{4Uq+jH+`_&6Brbq9}Dt;i+O8mdhx1Eo$e)GVVPBDo!ny_Q#Ag? zgc@)3(oATi{&v^Y^(`~@GE3l%=O*8PhMGD0A^$nAYbSIrU%dgE{Bp-Y8C7}=(}FT) zkW@%bMPn}4xVF<8QNbd;8jC7ptCNat^}IZAFkN{|1Zg;bgT*(&ZcF1>tzvs?JGm+9 z9U>RtzV$4Edvh74MN zHA2cKmK7Mqx;da}jMcBR6~iU!a^e|Hhg(CWmiK$FWroFRO5i#g6iE=LV7Trm+{1rT zOUWgjQR_}l=wLMEIw88_?OT`{heEerO!7Kk(RHXYHNJ;-A$yU?8#WHk?(QzeN58+o#C7cPej5k^ zi2~(}`fJ<4*R{LIb850CJH z9XA|K;j4UaMtN@7V$Sd2I!A|Nh@{c1TO|XZnSPCO(FFK-elv)SVU^ZZ!Vd49=gf#% z&MOySP;EHodimyX4c8@1Dk$Ep5s(VM>=}<%22lUNOIN$=Z#}3~Yo6@N{rDIWl5{(T=1ngB(7Oh_b11#)36NToe#QRNAw`G>h(8 z<*i}YY{Y|^^(mshFMPwpo^_Xr(uqtf8a0307?{ihn1*P@BJ~NskJ9FCD@49dPJSCa z7yM_4?iy*V7yfK4Fa!s|_;-~9F@bUYyhXabd_C^|ACV5gG5k4@IOEXA z;2q>=f??rOotS%#8UxV>o8>&QZ#O4(IxP}aI>(+bH!}NAvx&1;JrPwYF)B@NG(?(> z>BP1W(^`46mZ6Y+vrp%h6vh%hF{<{Dz7&WAN-7eVFEh~s_##N%}5V4d5rjB&GCuLN00(M@c?-fw3G4ayl6kf>Js?Y)Xkc0W*i zl&y2Blp2U4%bl5_fQ$Nq7Caf6h+FlU*J^Xnd#?)XfBuC&mS)ntI@pfZ6#(-4#<>^H z7F)seFJvyE&QelRJm33?J9C>8j6jKrF(B*b3nGmw57gUXMSpQyXsInEyp0k9Zm;&S zhtb*f=kl$aQK>*(^&931KGe=E$^0YFyDDK&@65k1`2YDmR?HWd2f>sEBWEwX7Uhpz zN|`j#ub=O;;^JT;$al#)yem+H3eNwh?v4V*zt4t_z375iJ9&^Jdp|Sz9!(5zuQo8 z5t$BpZ;-34KJg6>l%XA}el}Ag3z3*SeC0NNhbu_YPuYm4y8B9Lu!qdS*$-`=s>ntC4d9J+Bm6vD;1yjPmyo7 z#?dd>0R=M-wFr8-7nJjGv-?9>hZH=++!MH#{a6T2`Y2Z5sA_l^S=-pN3qp)QI)8_W zaBt7cQ^}~u{bsPo^#P?{sV;AbT)4&!2FwDs0#jxE45A@R5|?O7EIjhN(`?+_;u;bE z*(D(pJD0P$!`W!lmy;p^FTk0MrUyd5n)}C$OwE{7x=ph@icqmC^~hLw%u6)&BL{uz z@3NRBwmgVjK~!cA2r}+wYp&!5{6+*gqfVjkx&)f-7)SpL@bGWYgJ8Jth%lQ_9I|Ax zkJNmk7FHZcnKo()rpi37q*l35Mzh_O74rizfF6s@6VsJ-saiF3v2wDi3$iAdM@&SF z6%LE6T*|hIpW6drY++aq^pzy(?@O1&ZxbNf^JTaDD#YJZ0siwgo;pz^Wg@^izT6_Fy{CI_r9&e24cSGwm7>S2j#R~7Ff&GyDD9aU6hN!n6I8cXH-iA zZ)?l^KgLe)g3JXMvBSolixcDp9AHv_47wX)mTcQ~?}X27rP@4&1pDf6>&FK0x^VX0 z!?9EudT}25OL7V&B_$R@`sNLXKq>;||5l^_c}<58^{d$c;ZuruVxUnWJ2(&_T(kHD zVWdyS>x(w!xHdMYQt^rK1(2B$r$wa6D!Xz#KN?=;LGWEHX`cuW8$}UI7#C5RuVEtq zp)D8ldy_LEqLWUeSW#_d*Vr*3b)leYYIb?FWGVwGm;SS>!9<^hTB`E#d-oK~yMW^Q zn^O8C{)2kVl!fYV4i;|q$qz$REXHjrNxDMaJkMSVY$2i4V;d$nk~6b1cFrJIIf~JZ zHG${IFgcByYG~*J%*5FlbnU~^-vn?2C1r}#%3E1+slEBy9UBl@!hqkIhaFGX=Ic4{ zS8CZ&ECvQ+=bSsP^{4x|it>F-6?YWac!Aa0)T!MYwW(_=>I78rLtd}F@M530S6~l8 z!xZzT`QFulayPOSOtczKToS>=+<|FF>Qqx*uRJ=8js_pX&d)peQK*z$^ACjP5O|YA z>nK;ryT=G{oQ=u9j^wOSGC=;c1rQH;sy_3S1A;ahz*#Jlz&U1BeQe{`J`>cpqO7W_{>uxY1ZOWNiRyAg%;Q3 zn)~e@mz6YF+XY;RkF7N$WmP7;apSE_>9iBPx^cR%;q)2$R5emb@S72BJ+Z?8x2drj z`h#A65yJYYd%juhh;;UITMF2fxC=E2XAjrtWOAit!CcfXT~?pxwJ_2B`@K+r|4_gU z;m_jx0$6&kSerCHl)HfD{RE`3ViB%^IaP>t;y z^ii#iIm*>eEzC%`o+6JfHeVh--lqOg+$M8ZW>31d4mM8-nXLy z@^HsG#F;y+Ol~zL6@|R|Mf?#sZaQ!Nnk4ic%ENxH^b2%M%pj>?Fz3K#P*$r}p zFR1_Klk2o~(*HGb#lk5(95nEKiT?=rhlXY35lDvx*r>xxmLMT3=q){y7v;Dmh33(E zske@klnas^pS_U<=Z|hAj{CU;+xy>e! 
ztZwWh?hsdxRO`4n7Y6efYAD7+t!}8VX~!fVt0lZQ!_U0mAfbIlOh{_S)y1kp+cbuR zSFL8xhlmK?`O%0`Qf`x*JAJ49kl+&5b9o+x+L;Znf4>JlVy@EaIB~y3 z{-I;wTSg_)Ia8z)qzzTDE!-RiuPnP@vs_Nf<34Am*RPg8;8rYt zLghWFx++K?zS`c4TNTQty6wLv;YRhBwD}%Dn#OEaN0o zRv2l4rM?IUm)nQDi4Xw?Tu-sSFcf{2aa`>PaDE!Fa^@=f_M822+Yg*lQW(pHOA$sF zXj;Ri-@BR86cP_KLH~x2NPYJm83c6mYK|Si2zw#=c-HTQzr;N3AW8352nZ{Z2~Xj*$PaTN_osh**_to60q&wN z-3XMix{FvgS?@G@Hf-NP!bPlNAu|0y3DogdFsywT{NTmMP6t}0HfS@n`4I9-0-;$( z0@>W$QJ8$}gnQ^f3x_Sf zm+K^S%+_RjOruS{ii3ZhGhVqSF;aQ2EFp?vcocq@ESiXpk&b3%zC%t0xzVc1Nh5uF zL+SwPV=JNd<&Cl#NFx=ixU4bwn7QV=N??;0z zqG5;RJ(N7S8MaM&JdD#O0mF3~Fl$GHKE%v4!OZQ3#hMoDG%(`u-yg1>t>`3(K(x~U zrt!zR>iOIlqfd9|>*q4!r8eI4jeNKTM@iJPBH-BFmLd0PMb4_;d)1_MJec0=w@w9v zA#OzO-9LzadaI&A>@!FG;@w4(b1(1^o%=G0GW|7fQJGv)x(}NwtT2A*&|AdT@B8NI zGbDw1>8PAOV&m?{#TwBL2Q)UE@WzSZogOBJzq9q%)YJrcU!4wA47LAv3A&q$5qgJ^ zosBJuh_A3V2A4@&ugL0%EYzO?VW~>DtIASCE`yg7e5BI2mBVq>BTikGk@3Puj~YDi zx0Dz8{+X7fl~G^V4)31qaKy#<(f6J!l+j0+@D>+w=IC00p;?(SgjD1M+MC&@t4bYO z7?CRi0r2Aqfjk5N2|eDat1A|QV!>!2(S48Upmxml;Y+{PN6$-HI8PRL3Y|kX7M_Zc zO=D60$0mZz`{o?Gdk&T7OAr*S5FGj}Pof}Bx)hqha&?*9!*}1`W}HayArrWT0)F(h z^t0>8%-Cy`K9Ai$-|tUmhy>}fE`44?dmYIq53gg!O6Tq+JgsdVqK^ybV5$g9U;!eg zm8(Yc?)gRfq!r8WSFmeJ(c~~k^n2zq=FHKj;2{%1^Ga|bp@+V7t%n=+xXoqY-NZ`2 zWM0*)f*j0WX3@+5px4-=KzT6ihY7^Ss#!T#1716AGNW7L~;aVj-Vfv z3(6wsmBlZb%B0m&ol|}Dh+E@%D@NfeoRC?!RRWNqSy_n``UT-!k6_O%KyC$PB!zN` z?T;KVeV7@{WAxuc4COqt3Z>v?X-%s#xO>svHnkOd1@>NZLW*_WnLV*Oo-Y*@hqg*_ z{my*jztqEw)0S`twUZXMb0$|7>2T|Ks-yWI5|nK);hQCLCmtM%1s)a8HlX%9Yy7vcCw>i0GKV8iM(8fyKT>c z$gN-fs+FA#6&)L!_|p$9l>|_N^uV)%lkjPBtMiXEiK>XiXcA~=`ZPE{W6^(4n#n>1 z?d@3)CFo!jXX;|g;f7JGb*C%L{XT9enH!~jLFK*37MiBu6A2NN^` z9WltL9en||B#62p5>VucxFS%Y)hPbtkbd!I2XE+N6m=u@!nixG{pBO$(LpJ+AG<^s zujJEzaqMmFPXjTKkL#k$HY5n0@O^m8c>>Mxl#+V`iUZ7x451IKFrtzofM;l#L;!Na zW;d_^HJK%FZhT!96UC5K=+!JHX-#A7JpHXOb}9v zL@o1sXrS0(T?xVS^v`qAHGhC3Iax2w3oQ8eZ%@TtAML2-H>50&IP;Y|4oe)EZYy~& z6-^ge-bO$Rg3l`@7sKsct_UDymf7`HFQeCW`vEVi2-JIJ0zD}@_LT#!0FRY`b8AB- zG=$u&ry-pQ7V`2a3CJEdMoYHoKw*pg=?e$%J+_26P9|TXi=F8b$0@hcH^t5*NW(Oq zNF|$?w%-sQovHw*jS3GoSqQvHvNDRP%SN+Q7<``Uy-t9>@DgVLtg@CDcaROXJorv! z9SZ9hK|5)BJhZ*1-APJ}jgFr#jCKWA^E+Fi&h|ZTqu`a2E5m;zKmb=$ATG#=kdm1> zakeOvZ14(hvROl_nmbzrA1?&j$(6DNbeR+-hyiPzfvV(yi{SD4?<547g>>zB`Z7#x zz|}KRnbc2%D}ZO$&RP81`(sZ^S}G*MO%HTlEVd7+GZ-@FH=C~brG&aJ@;4wJ^9xXq zsE+MCHrQKZSyxeXfOdN%_8^e+h>`?OnB=g2x8QS3m499@ZLm#08jU5_`W}|@+Z?)6 zXS^Ct#xn__`g(d;e z<2VKg3;5}Mh@ECF3;e>Al8%7h7e+cp8(ha{r%l-(FrAu}q^_b613dO>jitS>uvkTt z_r?Aud26M+-EQTq6~ML_&(>Rtii_@5tD{w!bwwbe_5HeUr*rMiz_K2C?L8>TZk55i{{77E!B-JWz-G_8m&GqfYyavp|hu^-zzHC)UY;$ zONEPsOJwq?)#)qMk4q3J))oDC$x8hjw5X_P8kem}g??=e);ou4WP&OfE4xD(ak}&5 zLa+I?=Q!kA!n$RJHhra;wY$|>AfAQ(RDtrd(EQ(`&E!ib>w5Q$+ya5|5zOW|Tp77Q z=Y}ZX3(=X;Vhrc7!^{Jm#QY7!@{0@wvTadcDy<+8GxmM!uju4yEW&k#YXQ~jS2k5! 
z&fNTWftp}XVwi6>?xr2($5;H-i8cs+=eue8STsfXy)Ihz4(Y8Q_@Skw@8zO69?+!d z%Xx@lOlK3jowO9wM{bq{9S7FYnlXy-A~{Y5u#@2MzN6Vv+Kha>p}lW46qYXg$?Qpi z`uXcnrYkV{0q>DbhcUF7Cb!HK zqZ-#k8vsS09QI}aHjbod^T;QZ13I0n3ojIP#-etnAXLASzhK5Sc>t?(SClBaR;jF` zJ(K2CcosAP*$wP0*Uh0_xnyOf>=)$;ERz4;KNn#eZ-N_}prP6O-@3mI3o*gkm>*3x z<6~n+_NcWnK%e`0iJ%Q=^wnpxELRs|$;G{ath*p7@*hp!S1G0mO-G6fKHLG!lffSZ za#3_{ku(brckp}-)F04d!RmIzFwAaiAAfgj=p1g`lO*%%euTVS=YrjQv^3{Vd3XCM zJ0wkyYz{m@I%cL_-$hraphTw=r@$G|Gg9$-2sv82Dkb4zJ-Bwvop36MKR;Cl>9w|6 zbDv+6T6Lwy$7>R%PSq5q^Gt&Zo;frWWLO$VKGL zjQE;_m_)=x@dxRJr|WK}+?9lbqnLofwQaMYc=Uc&U-sZAg2l&;Qm@$VsH$KT@iZW= zkU0ZwJ>bFxzUJa1{_sIS+(B<#u6CgnB%~@4Eenk(2~B5OY7y=`22|%bt$FD8-y{Df z7JK@Wb?2Y_4n_qaE((*vwKzt=m>)oo^@PO$ylr5v*Uv5_bIv8T2 z*dLGcMEb~U_Y>5o{lS4!#o6V_FC)r>r|sR9&Ri zx0cdC*!6+rUJi;Ggv~WyIx7d$e|cV54_!Q!&fQ8ajQF$N7j6>GwatFlZj*C$Qh=HM{V)^jEVc+W%<;f7(jG6o(tx%uVNsA699R z#{m|}HtGZ-nZDN#6aBYc*VJ<9_oKS6N^Q=kisa%_<`i@j)ZhBA-6vGkRLYbgQ*q+5 zq@L+8vg$5=PX|)UXd-a*cE;25F-+OrqB3x-M->0mO{LqVszn(?VcYAn=CA+@%HByn z00-I)3cVMej1Z*A3P!0<`s9}lxAtph#{TGN_R}z(AlbdFPCTgEn;^EjviQf` z!xXqXcU8K2Wst2J(6=_YM5Q!8ZO! zrpn~U`O&2q0s79Sr%|C1fKKutIzA#6{9-7FK>0# z4KX_EFMy$9wosOwe5n+CES~8~CSjgG`R2a+QyI<};%bqo%v|O7_n+gn>Sxtw92GaS z!LCySqc6074rl-F{~Yd}deMgjB`grAKPG>eSQCOQ%y;Q;#$%k1=pOI;@lLa(c?2=l zW{YmBastn}48lbm-h&L4OWcDEX3X&~4(})FMbM^b8Xzy%-U2fn>nerjgK>rUVt zQ5`tof0q~i`+zj)PIhqml3vyW3PBe*#5ecX1tEpt;)8FcX6@JNeM~j0H0YSUvLi{j zTHo17sS2BmX^IZUE}s_{wP?+!L!B)uY^SX9{u`4^uo?~v7uR;N`sX!1pHP$e>|X0} zCHu7mf32`G5ZawgYFrm1W|v+BW2o(Z& z$5%S_Lx?$c+PnTkt|iBvhuoDxWXA6QbxyJW*E#9XwGkl!tu>h?FdLg!!p?~>7FT%lfKB*g?4=UKYt|k{ zn|e_#Es$LKTjRXg)v~nsFGhT~%f%i*$to^RRh(|r5=_L>lYWrO~BD* z`0B>dsI%>w7FGUhL+f96sDCh@KW&X*7DK4W$pJ%(;b1f|HTrOF3@)?I4!Qs}^XvR)V*e zm%*;%p;x_j+H;$4*jwW=oVMek#IFd3ulDr}|HzsjuULlWXPBeq+Qb(0$cVoz1cjx4 zP@FS%(h_VBk>XP`@Au7|Lc})u9dq%1%Qe(ZRC1H=MFFnVHfgm1=tFBKCeqx7R%|J{UcyM=f`vBXfg7`a|vz`?_M`k598i z-B`uZ!*S}j{SmQlf~DRd7c)1{@=#=~Xx}0P#$I01Jy6FE32K|8oYaS2?nYu)$PU;Y zjB=Vi4Bc|wbiGVuuXlGV_?r{!gspZjR~1mT&E1bR2q?5#*w%1xx+TX~VZV!VwxTgH zP-K^)qC*)OnMRmK_&qKce;x$D@EQNuGXD<-0l1R0;1p1XbvSm&06yvajKsu(07FAV zpk^7RE)%Jxp}>yeCutlOfX;Ubtp@sjNPH%Af~n~C!xZ;TlU!LBc9Hx4N7XsF=hbdo zKeig%X{^R+Y#WWu#%P`?uV>9U=9s@xcthOW zijE@QdVD8h-)`6bG>W6y(LVb;Ezt@`>s6xOyJVkf7-2LV2B=5u5J5Hb&h7)-a|+i( z%t4TanlPZQ_j61ny-<4mkd$F*)u<^*I~hnj7*m!`GPyz11f>kPSC&y`s@kgABC9p! zKb%%)y|3UR8H|YC${?_mcSqw~c~uj1*Rk-0yE&aj7t=q~oFqlXX<}0PR0;u*rV^V{8nYY+}-8p?C6?*A;P%BHRxB1Qf{LJyvSdMqcMI zpvhS{q#%H%xQR(|2*uFj*iH)C4cw)%WlC-p88N@_UGl_Q6s9yS#Al`q5@M2b{*gm? zY$%5-Lc1Vv{{PF%!Vpe03R6%;LT0v!y@F&GfbNru3=J4{yUpmI|77ycx_Ev z6&;6wU>(R)KQIxmH=Q_Hs)^MG{6dV9TIEV#^sgrXR1%shYFnLUArA*`tZ#CiLN$48 z40kaTNs?m6f`#1Pd&ZmK^nH&#>31$E9j-I0`(AZ#l^X+Ir{L%?>VqJKlfkl><2f>& zamm0kqyoL(8gt}Z^;C}BH@4(MU7GR*=pN}A1z6JJ)~%$Z>SJI$A7UO*0FR9{Y=OcY z{FO#fUz8JO`y`Nd4qnz3Rum~~cODLm5CYm5d~~zgEx)_(h5Cjs2n~)X=(vdL8z{k| zeH6o~Ea-9hkH>E$I*Xp zk8!2ME$99$B*doW3DK4zN*vaJf|Mr{vw~DD?i^`)35BYKz3v~pL}d9 z?Xk_M7G0(O38z+Y%cmeR%l1(%LqJ*S$KZ87XDQ-6!Ce%(?w;2*n6>boFr``>pLEs% zO3xiK4A;#dtC9hf(XltSRBvHFxk6F>E2gW}->-iyIn1{@8!j6XV(3PBJncVSiY>%N zMSmL?1}(t&10O~dK3R{7C4u=oC-t){cm#lGK?1sSuSR9Aw&WgubMypcb7^QBo(yf_1Y|hG6tmhoNp#lE-viJCCEHC zVW6j-m14ABjb| zW*69s=uCCpA8I#hyWVMZa$1v|vY6)9C$f}PtUc3jK$$!I${HRehyvkT1x(YiwNWuJ zyi7E`tASj(s-_K_`QJSgz_P&Q($sNsB3UsRMUrz7#+Ngov)~7u?dd*3{a7 zEw|$nuRb*&n`)q=Er?;DzGI6Arae>C5}SH<4j0h*pN z1{rnyf%(vw$r+3N7TPc;IQ%atLYl$NsI@oDlrGNcTYN`7iVEB@e$bR&^N3nfv-&3? 
zD7mW8=muFD)k^FpZDJF+bfg+?2b|y6m`whflHa>&w$gmrQf;I;QTz9J{~v|Hgyeax z&SWeY9*b$BlQld%y!u!|bBukcu|ilkURs1yDOn5&n^nTpQrvQ6zlTvN?3?>PC;Q~6 zadgsLaf1BfdS|9iBxy1!E&vjU_F)ciq}2$qhs zugDKG8jSnnfh7~t`~gD|P2Jw+5|HGa>5#r&=pOTczGBRJzU&nSw&d@VV?q2xGR!A) zY!=_VP;~kt64ZO&q-q@&Mkg}Zx6KeQoeKT8crb{fiUB11cs5IkDtYVE= zultBIl9PaXa$)u9ZlS!f*PrrPa!gg%0DEDN&gLQ{`h%ZYi{)Ypv)dZlpD{_B>Y(AWCcD>zVJZ`w!qS!6ilNTuq9i7Z#!T7)q z8->0^KA)-87z<^#xBs%Yr=XyyKH*N9CWX!0B?Ez$NdenpXJnwp#2^t-=llrTdX41L zQbh0+fvf&rGuKMO=wol7D5twI(A%3}nEIQ_w!;Zh4}$nD3rn)-l%n#Cv;YDd>NzUz z&8y)vWSI?y@>j2-7@~|0){j9gML|KyaEasdRP&N>X=~hlvahm2xn@D4!Va;NttZoh zG;XRdc+|P6hdT!Q?e?wPdoXo%`}^TxI*`4`fs_77N976g&fb}i;ZYTsugk{S{PNC) zQ+(`HKF`O)`DgHS73!m(LHX~+78PkbwO<%_RSdVIQhijT)EMmqZ%GiSh(;tb=AY~= zj|t)zf7@W28|*`_5NlV_;(#kNL@OzN>t)p^(isrZDGeN=v=tfYr%cN(*wCrzL* z0o2&>u&O1hM#(yxe|&_=VFE|~#XIh;s3&Rz|EmaYNs0T}9=A9i67qSaJ!-*`aQ!ZI z89FFSU9Sa3wd$*gk+velEG&-KI|IOyvSKNl>-+;f7S>v4rU5{($}`xEiYX7F`F4V1jyg zjz(69bHN$OFr>+#>h#3GVj+tWa&GL8x^bqx0*us#`8^mi9ts3PDKQPG$^3#7-0A;4 zZH$`)(MYPeAXB)oB^hhYuDyThC$++K_!=pZmAfpmTKUB9SVGRr&lWJHpoRTTn&>*| zvnNv6U$srV(n4D8Bl2UVNfGjpwwtqi=Vx*e-5KczkZ;O%Pj5ys@fksNeuM){a$n`L zY&`EpkGc1qQ94wIW}fAUGBbbcz|)iBXpJVF=dZy-QX+{iRa=)>eeW~p`cF}!mjQFN zlLbsV+O3j<$v>ByZ1n)^n(J9v-(BW-QUCI zN<;?%Nvc|;P(B$dEYBvB^lUc)Mz~a%v$}-Inbll)Lt6Se&L1TAN z)~4mzbTMYF)B8KVO*%jr3=LN5bPuO)dM?~ROb2MuwG|cQY;w44mgxB361eqdN|&hU z0rG#5>U^?}b9p5vMh7R@4vilPlD2c>3pN~=#yWIbQBq_Ifqt8&x{GS;RVF7|(NU?i zP9bC7wCHg=t^leMqc0=1x&#mnKum;(r`$vJSgUFozHb1GxBaJ)fD<-LVc4m}@W)^4O`+_Q>kmnRQoz{cgac1^IdHZcA=s`|wU;v)aW#3B^f`LL=A;Iab?8mvaP^dNkKELnW znXsg_JzR#uZIx`^5N0f6uc#>3Q-+Z3vGccZCAzT!U8&+fS?6|$;(o4^K6_u1eLBYKUpXVJ5U<5dG$on27%C4w6^9{G)HzUJ_Tjw%C@lYnh z(bSkL{ciTo((B2C(B-#(wKbO zogg-m2zc&3?vT@nc)gyk@`OW5i}QDO0CxRE7KiQF!NlvsMHLPaK=6Bo-KAtL)suH3 zxk}v0I-V;@l)~c!0Qh;rs7#A5$ozy%RG(`bnrROm#fZKxjn+6$Q3wvw9H?G zLY<37{La@JA_vfbP!RvYC7~KZ-Xvt4<6|fHiI@AcTIKJ2on|$Z7k2aXXnM?*?(!x|ZEcNy3RXZ^{oJ?(JD^G~18N76JP4GfPU#eca}G6g&*cH50VUeAYZE*JWIZVtBxMnl`HPlJ+@h5_hQ zyuB4u>XkD59^|PwFKoM-6iwKCh7cH$Rii1i7zU!5JnjzH`2!hEHj>sE-(1d|ff;UV z_c4oE3Y~7d&#o94-^0sk2AlOLl}@|)%&f^+Y$1^ex665dKR~+!j4UxGqG=n<3244UsfZ6m%4LS z*`9L@@;R#+477;j4kyO5i-|LC=fxAMzEYThgd8uP$j;~&(L?!tqBT`^+1EM>vinKS zwi6%_j2QX6yvR7o4{V0Qxl>uksiX3{YMjbbD+c(G>exaMvN8`A>vkBpln#XUOp+}t ztidq6-bksMNiT*%s1n#mFRAVFLeKrzb)qaiSa-;Mh%aa!iL}bxFnRE&lhXpx#wt#N z7Bd3@yW&PC6e~`9CfBu;|MgE@26I*Jd$!CxtHom(`}%p-0qv%?lzbWO!B|QPA%qHp zqI$O>Tg%M8JjiBaACYBjs#Ls|?tam6JF%9~vKyW#^Q$eZ+3+_t`$XDC1(gtLX)Z7k z+Pqv4%cz>WlY;pkGxd|-FsDKN?S5@eBB-x6XtjhJbn%O`tx@&U3$JUXlrI~#R)%<>tV*UkLM zoY`bjQj(aMn8ZOudHH&q>*2v(b88#8I2@Xgqol_1!}(vXD*bn`7%M9qKA*c~Tc6hq zAQS9`yH%Mz?vLxW&`tR6P(gnF@~~F9iGs?HKhFiR`e880Sa#%P)^5EQ+&?PI%5BUY z&Ja$I_nmi^YYB3EJT6xorg&b@eh*8ihHKa1MqJ2$+eV)>!SURUIOaWM#@?}7b47vsqm zd!xa-V|#o1SyuN>v&n4t=l#@~M$?J)B8mYmzGveL0Bn$$05K1F7=3o;3S_v_v$5It zLa@iDW?5PozaE?ASZ{VB=Gy5zX4B_d{KJ<%n8=`|p*g|B_2cASGhIcLgfDR1z8?> z^6B~ZaMSxRE8y)-I+fA(Z(~(mU4jZ-oO!2}eMbFf@9h+FYy@hEM3~Qha;iK~^t*^n z$jWAZ(Zzap^CaS*9Arz)_T3IcUVe>ZW+6CV{?lCtv$Kl;f%|nXs zmFXDN1gMCf`}?UYx@E2938$cKB2h)x1Ov$rTf)e-_zTqtNdp4f&X@?`X{eOaFy$Y2 zlIrOTwJH>>Mp#{Bscj5piO2)S>WIr(?7Y`sVWEndme8 z!53G49LYaZ0_=Uj_u~*?;0>d${`ToJisdE5Ka)spbJ!o1W_$;0^%dv_YF0^6tI8V6 zU1W6*_1V$!`%56Fz91Ai2S|&V%=#`^H<|stgl`#~I^ZrDNGbWnYO!)-LMfMBSynyS zX%Eg1i7zQ-%3!`8yjOhiRf(}In9@w)NmA+{82%7PTwK(DFC{s}diBq7!;&s7Eq!!& zw6}gA6e^S?GLOR^jGj}PK5|OV$6X27gq6k5foY&es;h^UNy&)Rx*I;~mrR^~0N&}3 zgRYN?%0erV@vw%|#psA|i_#h(Vu@{PF)&53vt^3$PR-0>ETZ=vKcqTKMESXEH*bK~tfGKEeMQ(nTOBPDyX z0V0W-$HLkIC<7iIuEijRhC9| z#*0Ujd}qwkSf)SP2||pZVWGXz>W%7gZ#}rFv)ix|HFeq!6zjT;L5#W`V>W9x7S3Ox 
zDa?@=VSZf>d(RaT7ItU+`4#@X!k?_o97iK=U2Wo&y3?X#J|+Iwz!*h1o4F|){Hf!j zW`-`lkKnw#fhL8NTy)MNCfkQ&=}du{T}ZxSKu$THVWqupDTau%6dUFJk9v}PSCD*k zb!W>#(_+XB-6|h?+~APVMKuJ`O!9}SmD-Z%M-*f%akP=q9aC52OGP7Vk=c=}RjYr$ zi~8bQ6%io&nSew|hx!J_0U{X0!ego|ZZU&*D8 zLQPExOlrUMTTGO8yZIMfA)yeZJOoxTxryu!&&BGFp~}vwaaO5Pq_68|qYLP8&P48a zGbK4*->9h4E`pOWFy`q3lQ7UR%k2tJxQsUF<7}Hv#xS#sQW-SBKsdgV^vv}Y)aKbA z*O_!1>#u|&6ZUQ=i>*xKTH`l1H2^c-lL0Kf!n^N0-)>qCJ8y@{kOS>%gA!(EryhNL z&X#MBzK~@DkR(-F^_S<@Tf};Yg%kdc+Y+ERc*SmuV&ty9*zimQvh~<3=cCuO)0S%V zkYBzc_*W}&5TK(l0kKXG1NmOqFpNZ*%v|n3jhI$iT{by30UjCDRoDLJA@=oB?uU=3 zr?{>z;j_1Zg>pi|u$GKm=hb$xtbROs>Boe;&o}Q?*3qC5D@`Cmq^9d0Qvi>ik5cZ} z9l+KdWmD4D)6tnw-)sW#z_7q2g@9JhXfYZ>5}2h{PSkbVeinttN(e;IZTFb?1sr@* zoXCsVWCDtPgsi8q3?H@0v&9Vt} zc7^jG61=|KyOBGe#Qx?}Ip3ESHfnA9La=n* z{C@^6e1TI!|2*aB%2RvM6f53Jvc!oii zf45L@_3da|sthM!J84+blW$2#_~qez;kPC6w)7TIA;FOK8^qvq5>7@c&0 zvX4n>TFfTJId7M*{BCZ#Kt1R^srhp*at!0Rg!O>D#7^=?OKbjVDo-pny9Xyi40f^< zxVU!;*cg_OigGifevA;s=36MJsHAQPzCIXsQ|CNl<|Ktq|H}F%s^q@2@V-Y)q^DIy^W(J=x#g**>FH6V774 z&cO6t5$4!9HjegWGqWq-^fj!sVO*&`X^>8)uShYM^YybNeV0V8|F5_K)yG9&iq_F zB5$E%nGhbK&DD2kL1zp45|mbaJTDfK7rZCLfRj205oUiO?5J!r;?L>A$avAvQ6jFl<1;R`I9_g*mgFh8FF0`HI5m!B7&$2hz@1 z+}m2qKDz*LEzy&`YilDkA}3bNii^mtPj=%pWU61!$WClv#oVYW=j|lkjkHX>@96 zU^_HvetGgliY4$9(%Qqr17!zq_%Qt(6dTlSCII~o<^lo53${nA$WP@!A{k`tYL$&Q zBAOo*7i#RYDzKM3^oB%TIQx+0x*_N;pud%1GWPOkXS^!6J-E|3W2r!0{y6>k-ZsZ7 zK@iDf^)nQbEIWW}Yo@v(Ol(nJc#qw(Ejm5LzTi$RG9Zlna}4c^8GL0$}cbGbn>44vD7MP zM}*FnRHW*Zl$4_4BBkQKo7%0i17<_029YdO)4xd z351Nr`R?)*_JugPJL&xV9JXM0#Zc5RzB0hJsZ6jB)yr9TBFdpQq}GvQQZ!VS{ESaI3Y=Nlwjj& z$EUl(p;Eu~AxJ(>yWUR$c2O8YvzeP)Kkw;1lA|qZ*p-fxRlbDuuusP20!O78ybu-jNLVq4El5j{5sZ{En z;7|~WqLHTM^z?#onK;tC#$W)0O8WIbKlqw-p4g*MCB^FGpl^w&TGnV|79|aYeGBES zpnI-1w_dx}yGS&<<0WRe_Y(lM(_7z{YSMeQ>}o3{ySfMZDi_c3f-rpgQ&Ia%izMPH2L zGEKp^jTI%Y{oh|#V1HHRe(Vlltu%>TxjH)~sj4s)7iTM7P-eS?C zFmz%D_gNzlbI4p({S!>wzpamb}LiP z?$QS}x$F*$u?dO9N=;26C6I-l>uirF8`iIIssq>k!N!qBq&PWKc~J$j|>X8YZQVhRE-udwxzV=UNqB_$QF zb4_j?z9Gu5a`L}MVNIMLD>WSH>Cd}+0-b!wT@xB(yBmxRZku+iF=$Cido_~gsiCj0 zuTL-AVwEM~#{gRarCw`)52vApv39mP+TbC{1EZ~zDw0&GJRWn zdFejX)n*`(=&#FhmIPQ>lw63=dJOn0I>qZ!=`QLe85uQJI3u_SHORb*^_ska+obm* zl#px(cOK(ZM@0Qoe(OxPJtO-W;qa{Pay1j1BK=0_H!T$2<05;cp(|41PJ^O6AH8dnA|9>T=gUU@=U z^=j=hj(#KZBA0GxcDOR+$)&Jf9Ajmq+VF{8dEWh~!;mHyj~-nPw-*b?h~~enIhUK` zz?oQf_iiToAqi><32SxbASRwBGkKxGse}ADJrIpg6#q&?GHcbORaIbh;7RAd6EUH! 
zuC1?ibhH=glx_CSC2H#1m!|LIfxzJ&T|?VRt~xTluYY5OxI4x?3~#KRQT#@yR(NO} zLE*&C5lfKB1c(NWIeeZE<87B6V~ebXDt+<4X;0ihU?0XwdapqyB(wL35-yp%<_76+b6qQS4$rdi8q~I)%|G;2N zW!>Br7LH*t+HBU6vCzVMB__C_lBY%y<*>3zHcQRPKwS`KTd{kO1=@)8jr9dW2!x@d zN4^I@fZJ*PaBDXH>NK(5<9bg#&T$)UYzvwhp-1scYrX)&w^M!(vswmSmUm`0k_^G z=2BT%*&=iuN{>i^eYX46KPdY`yuQrt+xfvb={WOx%3$(;WgGcI7oxkq%$0n`em*k- zV*z6_g2rZpNTQ=8|M~)sh7EcC`}vpGB~LeVe=IQ1rE=B1^I43$y>K)uy6FnT@!Mnp>#DlJBR4_S2y9L5B_FY z+5V&+wQh}HR4zU@xV|;KQM7$6xz7ygJ?eCNi2%JJS2dw7kpa zE>+KVve&nkn|ddTenOSmkA3OdQE&8?-m(WRm3Px4+1bnvQOp?$7Z`TrDD z{|JBBXE4*`F@ym;J2#1ky7RZ>GCG7SVjjvW#QV7W~BN_Yw*cF0c7^l}?B zT#kul<*GyA&3{z}vWSe#^?gX$^fiPZgYd|i{)rYY*Y#Pt=zO(lr<9tlSciIx16^pN>;?e_3dLc%H`&e zx=lp9MY1pwmlZJ#Z~4dBz* zN;yBl%(A0x#d)sfT^g42V#hvH?gnr;43$~R4cZ^ z*4nDn=9`OnT3bCX8^TT;*oo_=Co(4ndV;Ly@>a01BsV(fU;m;6CJS;32HHTJZGMU= z&&RWWw0>%xlC_V^&x=dcp=fv*q+$doL&3u~Gu z{T5+7t}?dR;QTn|%AdKNfS9kF`CEsnFuwnqrvy zP7SOI#j?7E71E4zA<_<}4Iu%Tf@h5U<$ET1@oWVZi)?0^5#0rriB4g-U-Bm6%m*vE zyMNHAN0%?E`Unft1mb_6q85c>VyY)ytfpbrjwew}Hl*7nvl+x;sE%RKt;aNNdwKB;60uyXM?U|T&i%iYkYG57AKBk8SBdX35;LSMkvU!pa&YAC zV%{4Z$Psvr2{q(k#ucEv%4HLL16BeC=Ya!NS?Reeh=n*h3k5AYNu*b$*%I6$(P7i3 z60AmQ1p0Kz6=JHVG+lWl$$~L#nfM8HO=g$6xDm2II%o<1b6H6f=24f!zSL@i)9hk( zVX>+96!v#^e~u>}-a2WvHVo+rgkfg7n#$Z7T62O>n+uyQ6MsQL9)!f3wd$B=J!~{yOCzt$7mAJzZ_xDYEJCH@EH@f}eaNA%r)GKt=+#UrhLmx#@(E2sOt- z<1OZRiS}!JG@fAHx4Lf}zU?Nndy@KLd$v%w&Md4!XA<1flbLMF3(D%7cF#8yCzWL- zgj^2fYCrgRi4G^S+V*#26B)I03CNf6kH{Ok1$1VkxAKHR8znU980fj2j!?dXi6Y^N zp$XsciX+6v5TVJ>!5&%N$;@zP7$y#e>FA5he5osJ=OJTJcFIqJKg>c5m#h3mujY{~ z+ICgtUIDVgNX$gXiqQC$Pbx?Fj=zB6f0f_T=CGQVM$HMZl{kkmm$S00v_G2oUva3) zB}ySC$YB#&T}eny2*Aa4f68|j%^a7OQXOi7MH=Vi5WQc88j~qSG2Td8$Vk2wv3KwX z^iqSse;v8jd})|^^qzDFwal~M>z71>1joLH{HSwsinAFp-B)>Ww4O6{T8Y4@>Aay? z=Fa$8&&FX_XbDzRR@Rl~#SF}adS(iC0j^}ROjfJib^w5soa^y(xWD^zx7NP)rFF;k zE)2U|5EK3jdU7>xve862WEY8IggMk~kfnN`BP9|ur7ZmWFJZ}fN6R;6^!cmO(D!|D zwGb7upJM707H_G*janrq$QKhaXCeQ^T!F-tGOAio?wqWLz0;I-Nj1Znkjmiw5Yv87 zk$-}F$4pPAC2jL(VMY;*X*NUZTvw5`$*fV+F`UsTN^a zD$P{o>TC0$c;k7#>q3D%DE2qGjU*h!lacy0dwlZSoc-{@@lV&^ZNiAm-pK9|?6K04 zsbZ$eyYojF_15zOou;D-qWxyMoVik;fr~pyNtwgxBjj*{C@|FB(e@a3N*VqlDDne4 z7o!ws>&ss=_@^k&Y6EU86VF=0~vGH zW^#o59~*6#H>5%0ARs*!b0sr>u9s(w$b+QJenVnD-{yQ~R{sJV5KgXL8sP}6T%Hg3 zAdA3Q)3^D!FLP~edHKhd1obe;d;+g(u}rp*vCH{}C?OH8lmzw|dOQxl4UdLro9nq) z-|)p?ydi9NYJcY12%xdLc)d{Dr*PXR!`TPOcs#+dPvBsXm9Pea`0{&biy;#sVcTy_ z=jG>dH4D%U(EWQo<2F>b}_({#cUGwxr(WN z&-L@J|Ix zx|nzq$@2#umtIMkAfca%;5~$*lXP97O|iZMq8reOb93fxjzov*=RAj5w`w&j&kDO; z2ILcqCc||2DF3>(C%^A2N6TWKVQ?|7j{&$_Qjn?Vh z7h*S+siGuP`o+Z}bgOc9cxu)DxQT%6Vq4+_;r7zF<++eMgc|I&@InJR>-_(` z3E27&ZnzLjAin4~@#H-8h?wmn4YcLKy?} zIeA=&T^^+|h$e{P|KEwQUxX8Ie_()+c6s_ac~GZ-a6=V7Eo#{&PIz2h<%tt|_irw_ z_78s{<}S0au^h2b9M!hc2_p5*i~^X{|B~X zx8qjwW0f8wF@0zwA%$sAsM%gv{6FE57@h6td9izY)XMBu%P2(r9t$T6;>&f1SMfjy zpznV4Fff?(`@0!DC_arLBRf4R>0MJ=lrnZ=D0whHFanuI^4NCHBVP1=oK&uy3O!VR38wNpIr ztawQr9Bf3Kzdz5LKGV{Z%Wm*u-oe0zvFT}CRFd@Coh*b5Nu4)Uc)LD(il-hgL<>9f zI?MADctk>*Igbk)N6!{Fqy|PPl9SKJ((|bB1h~GD29!I-+0c~nyM20la59ND8tFK* z*&kJHEzUe$?;aDopACO`m>)dTp;ezu;}ZTwL6Qm-p!*pdzUW7nO&$}tAI)~xS8x!S z@zG!*q9Os+JF#|>MIh^>o#*yEGzT;)^%D)-q!^288N)q+*denTOE0!b|8XP`Ts5%^81xPI7e#} zJa`P29^WF87Y-oo8UeztuKLf4$O>@|Op999pSHG>xYe7~BR0i<+4$Y03a57V!iA$f zZ*XU)5z~u=D4|W`Sqlo(INXaygkUe1GyLo57KqoB}59 z1B}phH*~wL1;+Pipdg_mtqrE3@*zf_fl$y~`J)juN?YSi7(!eGB{g+*Jta+{$(TG* zNnaULg{u!%47nT?xF+BLAAM7PDV+jJB8RGIyci9 z7J!ivng4}Kv~7vyTcgzq*l$jkvt_`y&I_bfWpqonapaKDqI|&=y%prqXti1ixqobH zZ4JW)^O-rfpLv5Uk;~LqPq&{dp2BU4p3V_V>fLbz925xb*a0@oI6dAtvBV^< z6mXHQehT#ugMB>%tS1md^Y9O;q`i?zoEoGzCqGoFj%Wr9D=pRJOM}Cn& 
z**TQ%vn1|mW{Kjn@^=XA`;6bgoSmJak?=tb(wQt0M2lQ<9<4B4Lp(IzdAF^hR(5QZ z%?X_%O_8MK$FU13E&G94&{P)hP>a4+@<1eJMhAibIgRD=PH?0jS=n>Z5{b!rBqpP& z)(OXkTi{@fUjiozK6<|#r+|a++gy&r?Ts;n8`K9dBb4akN0ICh_GoBfwR6-R~vNnT^MZ_8Us+6wHsi zBAKUh_W$h$@$DA347^UcIgI6dyXtzM(fwnutm@J%9LXgU^(f&F<+Zzs4!{Jl=ZkL zCNlcjg*Sau!pC;qKT#DH4zYRJ>SKKH6_jmFgS$fhVZU@EPUML|r*x)5#v)0OA=un9 zN(*}0ABX*U%l}b8_QiRM5LG2OA#Qpan#iYrXzaF%V-p?yxXMYOKQ4M0_+H~!p2re$ z^)}8MVm1+dE;bd*W&u8z2kXC$2W!neYyvc)85sq~_;yvpDhsFRJ7=}P0M$}g!s}`~ zw1tZc5)yJ6l<06OW2m*8)c3V*Gag#X#z;s0g$En%`xjXIYh~ib!Y~KXXq@jUQq~~i z7V#4Lk=Mmq@!-Xg%Iu+sm7CLUa{31x=1 zw)jUyUG~CH%nGy(ZtB038Rg1Gii{PoNg=)<|9$_ln^Xh+64p0eH%z?nhiMX z76WqM9#lxy3^|bItDZApFG+2hh^SO!d9Wr+lYK`0?-Wns(N(wUeJF;Jpn@gM6=0MS zkGE1Z$uLhf$Y8-&Whr!oxm$k;l7}|8xQWHw5wdNFks}6}B({=WZGOJ`e)bW4OfN4F z!#}Y7TwGSH{g*!6&A|@F^BV6hOi{c{kU={APhQB-y{(nNS@0>D!0zz*-=@S6D7f2- z&Cd_OCA)A7GYVpZKB3}GfAsSDgLlA=1RNfH#YGe{Sd?ObkFS$VJO&1bTV;W4Fl`t3 zm~*o0fYLZPI>RF)gX)Iry) zf0wh6w9xt~E#c;-Lf+6pCM+Z@@D}8g7c}%xmE(1@OZ1s$+|l(iqsRN0by%o4GMJlJ zTb4^q1QMgIudnf7gNZc1xCA}`)(Tn`;5K?DLH&1aRtVKRoQ$Hb^4ELrJ9wyQ?U5bL zX;~_kRVfVv!^lWWk0(JAe3X)IE z$O_OL80nry7b+M(hJ3C5w%*qc)lBp_0SI31xa?OH6VI~sWHe^nR7K9cjy=Y zacI|V!^!w_n3|_WcoC8Qf$n{=tLeCpUQ{nJo{!EWVp~c7wa?b~bqvU$ay{4k*w+K> zM!{hhi~_=@lI@9h#-3Wz+tJ12V>;GV3=DfMbNJh)--;8Z6LcRU`R~OhGrqfIN_Tx7 zxctR-YbMXLy?w#HC9P3jhb=+*_xtxI0`8kOdMyaYx2sah#YzosZ%tkM5l_#@3tn%r zX)Fo~-Y$dp-_t1YS>PEobTw|i0MStD(^NBbEeKDRwK|*S_50~kCdnE=GN@iHYx<)l%n$PxlpF2H$&`=S2*T%TK-b3?GC+{u3};tout}% zU|F}d*%L(EBN~4fpodvsu_LbR_~Lv)u_Koh;L)qnYPtZ@L?3AW$Gp1b1#vVoF!$BX zUw}kZS%UCpg}`%vG}%-`ee$%4>ob8=fNvhz=L-D8Fh*X6$KxPM!=OiTQre=V|KRkNdB596|!oX+&I6Vt~WDq!Igpxobq|qz@Scf z>vjBNzwIyX>0vnP?^DVgB>rhoEp5M5P?PCfft7f)g9pY}E!n8ycVmW9$ay>_R;%wyuZ?heGMmh4AIXHp}bwa1a~CaS6l zLrY_)ouck&yG9g6Q~raXl+TcOdM*PP`;?4&n4!*yh^|u=f(AxEe~S5Z7J`f1@cu11 z0AmVlVrGj1)6N*5CU$(lB5D`}{^4>i4dyKY62G4i6=J0T&Ar=PN)>sGp z{pg*!{<{J^$G*+rkV#lZ29~cS^@Q}&MByos2tfNEztc2O%CYv7E&wi%isPun z7d+9Q)C(w?9`?Kgn%l&aaC2{G>JY${Qdw^&Vpo^yNiW#``FO#A=`%`wL;AkvJ$`@B z97s4`mhKt?_EP_qDrDSLWQ(4k$hJ>oi=K}>f{@?mEwM^#W!fF^{R5&SV*12nK~xn& zOdh8QwhGF0RF%93N6;yidt+m1*WMO%l;h*|j0~tC8118lx`=?SO4unI2+j?^Qb>T` z|6}SMzw7wl_wBT4(Ac)w*lKLsYHT}cW81cE+qP}n{LOoP*7sibKd|?MJ$vRg*L5Dp znLx4)gsEzi;6@UF?(<|Gmq=_98LeW-6AlOz5Y4;m^>P$L5?TEmWL!hQ@1t z8~JI7vBb^L`r?|(*#al=c$A3v{IGxFrR?)z7S9uK0h!-`d8A&U?yRyP!o#k6Xne;U zsD5hJq<>8Qa2Wcpc_rn5puk&}0~oZBuMUDtSwy+vu%#;1dut`BpKm49s!Hs$Qc?L~a>uyUa5K{0UIq09$;_ouemn6b zP!;Z!JOi_xhP87wW9cGxbcBcC*(#+yyR*YER=*#}D5OO#zm9W3W zgdroIYb!!mP^8v2QEt27TOkWBsJ>HrJ>*uEUK7BYgVUx^ha+ZkkHW!H{m)#rieI~m z(?X<3cL%$wV=yhb_SU}j&ty7p-;6@F$&ghCI6p&JRCd~mLc;?GE7QO1@W1Moq}sNl zf`0;R1VT>zn0O`CmE~|H!3D#p$oB#Z3n%cspEuvG*qE93(NW}x`D{LcSlgf`z1bgH zxB@ed=yCVZ^iu~tH2gkNNh@Qt9OB~1(+B8GUge(b)0p=~#}1C?1C9M|DJGcI5?*1Y zq4}CP_k3@{GMC>ACHTQ0YqigBHQQ>rel@Zz&5}=r!|Sba2mU1JcZKyA zaBM<;`(T$O3r$f`K?^g&kCYC>bF$d2?Xw=8*8dyTQ6Of6CgtI3{>kqq;wMbmCvERF z09HZwp&-BaS1$2*&!riTrG?)-JZN56TsBhsbfBLdI1#3x?NLiD9`AgyB6@Jy7TNN( z+nf)>^R`q4dY8$w-qActF;Wr8hixTD?Q8qR<3XJkUuu|5T*b48!YdgF@7YJnW0c_` z8gU2*WIuwwpSHih&lSls31bgASuAb(Yh%s}k>5B`;D~KL@5YM>S+RW>of>Nt%X+VqG1bPTK;D0tNZNhlbAP zL3lWrA{lh3j*3&CyNbW&f$9*PI$C*UiwP$c_q!-IzWbjA34X|*kdP;VWEB$}_YbDq zx)V+%GL%{!f%kYi%ub`};&JZs9+lYqQsH4t>u$mbPe6~_NIoBAWN&q(k}4C9_n z&)7#|;NOP*<-l= zp34HwplBq5J{rJh%uS_j2xrG5e88p|ncqiKE*WZ^JyXSjG7e-FhouuVU;I5oz#6wR9{Q1#)3kB7<2npW-ScyPZ5u&%w(i2Xv!yL&(QehF&+EA{n*hsGUx#yXod`}=NjxVyWfg}P17^bVFF zbNrG!zeFnw(l09hR%~%fQd8H3g>rIn?dAr;LqpGLy8H!a#%Uzb^S)6y=3Ri0=U9lD!TJg>I5`1y zAbn_@0u!iP`JjP9HKD(;lt?5ttdaLF=PSjVj+juTO}2qBUgoxg zp`|NQ`dOKfq*!+O^HJU^f{5c85dqRAKBvV}4%}akjj(o4&JSUapb*u~ojmXud? 
z#HT20>F{#DqiI3jH{LzhB=J7W1@{YyaM!p$XxEvD!>ZQS>25WC*c%sSm#1(&nMCmc z{o<*CgI+~bmMdL+P4)F9Eh&j2?E;u=_WOfvDjOmGB?B=*Iidtv#tvb{K(K`ns-TLmI zU(hctf^Em-t7Fq5S~?y-^1%U~YRi*$$~aIj>xD7smC@DK4RV=tc6$`1nwo-^jt;-B zTHZ|^dBPl*sjM#N9pCF5wy`xM63joSDApXd+uJYDs4NSXx_v69pPQRug|fQzMxmpH z_{jS_GYrW=fKpCj;LMJN5il|kch$EI6L6a z&gg8Z=$!9fS|?dr2t2{^g-1ml{Y%ei=-RQXPOAk6UzG$8*V zS`F@M!OP3z!|FqWZqE6$ECNJz2_U^cK0ZvwTb+VzE@Io0>Wtgw>TVCOong!t4GQP} z2^Y36=%jMz=Uo)GEDlBuKwE>Unnz* zR+y*t)z%CJstdZtIs6icn0+1LdfSGswfW_8q&`@TW9kg_dklFN{R25EHts`m)X4|& zRMdl5$gzro_{%vftuQB8bGLGE`S8F+-J0<)N3|)>RgNv!v zBEGM;_tgoSt=M+IA~6Q{SEWv+wYT{#zRoG%B=SQomV~lkU&1&|OwZ8#_4T!~!y$81 zBP4sRA8id3)Nn>Soo5cpIo9mJBAY$>;c2UU+0i6w9Z~Z`M1-o%dNj?DI-l;B<_D^j;rNnR6rocsk-YifeW>4ouvD@(KK7Zr_?RLv>$((KlcTP}yj z{8}VtZB+fJx-d-4A_av^ZI`R?{+ml4dV00Xm$N0^kLS1za@p_hrGbo4P*FN!m8BKR z55o+;s&zK)0H!s$H^fks$?_=LW_z(Z5I%lorPk%#K^9wqdl3O7?A%}6BBX3@j&r2%hiu z_-Q#gIZP%~i^r-q9qBD5bHNT2ndtY92}I?<7UeXpbRwmuy~z={uU1ALKHkRbsO3kW zpC1=bnon#%;7~eob|JtPamP!>p_sfg!?0lGvzo-Sqpe{J+iyZd&_>`{oMEZM)-myl z!G360Z6;1<-)v1oz($ptPZ!Dk@ndci5_MI9Un%O0!KUS)BQgQAF@-Q<_ zDf3%j93?`9CU$dk@iaPN{@(g^WI?n2{<(5sJn{aLd#(Ojh;Pb#hG~wS)`-vX+BTv2 zyMeimWWA=rr7G4YppOQwJKn}}tMlZ}YvJJm>TQBBjEaix>-_m9v5~d3)I^%z)H5$o zAn%|1e*AcO-^d77D>n`r5*8m3R(sW2hbC_<`cFL7t?BG<8)|zA-ZP=&uBR zDtmn^^Gft8ms)I$UD2{&!5Fozo~~I` zZjFX1c>!0RCC!B?MRrJjPWkBB{_=Pe?*NMh;=4i`JM%uGUEC|I+kW}LZ^JQPMZN>6 zdw*{xnFyyAs2WZtN7JEt=YjEOzHu$3{=cQ#Je}&alA=iRAYqrkt~;ZJkk3ZC5XmmhrGunjim(ES4b|9s%~~ z+mcg4$wD4P78_ENr?m1de9qq98Hb0gMr`%x6b7I7O8ALmJbPBENQu4M6SR*Bb6sU% z;UV2*`pWx&_s_TpJO+b<^ObrsnzBk>+spHSdPk*y#}`}<#(#wRlVga%(oG6)@YX?j4cH!B)le=uC`bIw-b> zdu#1)IhAP`&CZt>?hf5rS~&6*6oh_@y{~(`&WGRZB~GQwFU?CO?mp&;lA>|&vN|7> zr}PmMNsf+Dn#qog_HwG7R>qO@OYY_7`MrB`adJkHkza9oYZ=5apZ&?p%(vUqNY&)! z^LPXlD@d^_iH!L<1qqg;7eVl01aY>bo4nM<6EOF8oN*QXq?l-t7YN&&t*ge z#OgZY3+kT#0G2B4AFff~(Zq>JK;^`)MR}b?hpRT+-Zj|ESB6B^qe`AMPM_f_Ad#03 zd?ILhgxLww|CHZK<_?N8k@`w6WFw*k30RqE3Ys`I68iZqZ%(Vn{lqaF-Ch^+65Ipc zz{JA5cY^pJkF9srRaISXdIQ?q+x_kKj3?5~atfycieU`oeRCse@45I$VSi3cLTflA z#wP+q{o>S;31H=!Y&T1^7HhCC;HEFsmgaH!)KXDVv5*`qZf;%y0=CHvez}IpE2g{W zFL(Mm`geyL<&g1M7=Mhm*a5#rZ*z?zIhV7~f+phlY$8=zWC20CeK}(>p!tmD@o=s? 
z4Z?Tx2iKKKIYBS&FE8CG!-AGgSI-+1Bp3GgmIi?^Rl!2-=U);#fmf!?5!G(yRN;WIvpgn$WvP@1>lKGn-Xf8&4f&UX#875F4JlzK zeif%~d%4+Qp_i~43=fBMvi!ZdzPf-u>C&&#n)g@brg?F?dZ5$1?Yzd%VFHd2q+BdS zM1JJosXqr{@p(OVg}#a2u4ks`{supI?%Wn*(liCZF10Y{|M?Yme23&S(ca>7cbr$# zmnRn5ZnD|l|B>h2bzlgU2xpX72o0i5%H#*u&vu*$a>FZu%0Axt`I~&^VdJ zee#6wb31S{H+~BNKJ27pfBO+;oseFp0yDIK{|{V)ARl>ePgB!_bKqWu!DP1|-+GM$YU<1=sK%Ey~&t-igswz$2wXOm@r`G7hLmyhs{C}}G7VNVnC#ciWBef zFp+kP>U){@uh8WkcI=iHHvAR_I_f$){*lEBuA2ftiwMaZs&d-O|K&D2_aWuNq2SGS&6~FD_#wDpRgqy ze~wIyeO&j!i}d$bV}>P>IQ=EAWHc-ASooiBio@m_VlopEiAb?hYroH>MYY z^0zbEoUNc>JB}{f04(|L-mxtv?rw6;MmC+ji%Pzv zuR0!Mb+BIwk}@{|MRa-&-J>_J`Ft@cP1liOp-Q>x0*%(JqC2x$6@vH@j6YuRhcc;l z0+zpVM47$H0(=pnHWG`>s31l@9!GSYfnH9(lu#Z&UMavjC@1A-V-OG-7vMd}jYG*q zATeo{dPf9vl2)9B1jh4Dh<{D6vp9~}9EO~>bE<62QN$=NGp5fz6TD!mfu>I?|DOv$ zs>ITs=E33zsrIoFbwdywF4aH=Cd_x7pJ*!z>1&?E%JF7xAfk^CEau0UWwYgT(b7oD zhv&TlP&Efi=hr(9&sa{FNT_(cguFVll%G6aED!&T#@@YbtcrN#SqqhL%h9KsaJ2?t1LDOQc;L(RGb7wScqMq%#mpPN@qOS6o&S*X0?0L zWb}HC*fog#6w=U@(bqRuQQ5NH|10xvP*^ALn`mV^m*zPiG7<{2%`FKd^T>*SoN-2i zQ$sq7yMgik;l85{<_{j{i=D~y>A&4hVQ2>l@u8NYl@W!B#-FDkG&*Bb z=x85aUP+VbE$*A^AoxR4&pa>uP@&;JV&M6Lwdfq0!3h#yeew;O_6$87!77CRGKy zKi;NdaXO*JCyJ?aG<$5WcX>a@l9zkwxi(2pesfl^mcP|(tnQWOKp{*V0qUh?z} za#n=!LBEF5h3LCw4A2dDdU}b6FnxA-Dzj`aQnbff|7Pe4Izr<(GS}w;@#U0!W_npV z=3rT1^}Ms#($d&*pKdZv)#pYY0_Y=vm=-xAP>++D6CA-(3w;c*`gbEEIEDH7B;@3G zx!JPY!a+(8|tG2`9SSaA4|0>AKe{*+*Mp<-STq3lNHhAd?!teYzOuNYw z{p#a;7nG8MSp(5(u|L?8<#l`7a9PpXt@?be8cCv|rHfLL=u{fWGcW-8h3t;?r$iBL zZ!H|;`ab_BG9X*4vzW>RylvnV(x~jX#-kNQ_nR8m*o zejxw7EHt1W01Bk$=3Cm+z0=)k%I_w}L7b$jW$Fnz*(9Y{1gwqqv3g=@Vm%za38?_nZ#M^h}y2w`KweWU&qN+6~2G*8@ftSVuKZh|0ZghOC)Z$O8?s&C66B4qYr?nNu$=N~37qaLOe%o$~LsBw_Z@t4Ib`Ms3vdd7_3SY<*U#3j(T?67e+RtENi!nw(5q)G+7{trYdToxdY3l>+=*1 z8G-b6ord=9X3<3N%DdC7rgq8Vzu^r)6>l5L%f|>2^81+fgmaA6uvENWro%THbA1m) z+&2FaMkr0=-XH9Q&5CiWhEQR%i4gBnA@KA4c2R?v8lHuy^g+|bD->^pz=U+bmSBYj z#taO%DM=cR=xL6TUxAVT-15r21o20N_lObC@6h~qE3U?&p@E6ygJrqkzwrRU%j4`R z94JvUw=WQU@kVX!iEx8HFAcE{B0^`XQQX1V~9oE9r+aW zWz>=^a!$+sR&Y^Md3}832EoIh$h_p#A?sSBsQ$Y|MMYCl0oI}c!osjeGTKnNMKvWM z>fS^crRS&JD7^cJhxY2=Kfog~7>L;E<)C)`nyd0Gd;|>*wWVP(^YXFe+riXQlG2uD zR(gzeXJ%%A+zO>5XZjUZEFi7|{F?^5yPq#vUv5xd^zH4_amnmGV|{yHpYKOSGFhE` zY^ov+ot^(YmF5<;+#(2s_ef_%~csyF#hJZImy+x2}m{0FH*figk(yrIvOk5mL zV70q`zV`};<;N-H?EI!lr%b@lJtQ`Mp?qP1xs^Mep^o8x{TGj<*o$BYs||H};vSmP zJb=Wkcsnh}@k;9FI6m?(MSyU!JjYmEX@Kw+5rvQstSL{S4FV^YT9qZ3q8NoEvFJyraZK!xt(=;M zJ1=cV$=f43xP;eKzRt(XnvbI=LtAt7A=z1W%Eb6o(GRuWn7~g+S(u{f?g(cmuZ;U2 z(^TbeQ{nepTWjRY`?%%%F}n}a(g_Pq)!g*~cjR?!_*+Z!V9nMMZK*0>QvV9v2U< zE?A6(>fL!j9&60Ph3Ib(vg4G?&H{-#3>Fet8Wx-y3Ko=+phc*N{uCQerR02i_wY0k z&lZ%nzJ_9n5;kTlEi7cT*b8r12ZuNto!QDju_5$y?1n zZeN}6r&WWVj;bbQBswvgG`!2Gv$kj?o;LZ>tvH#<1C1RGu>p!v{FI`g)xn=2G|?R` zYN}Y2ec$RcXL`cReE$zKVY3KNUl=u}Pzge5i^R`(d-rk=k8y?8(_VDtAT9XM6KMP& zhZ{KisN~6<91iH#`QIK-7G@`;!UNd&o^mL~|C|BsTqd`(K;(~3IKN}_}qEwx;j4nj==4DO)x^*mzN8aqFr1ay$z<>*biME?y}mA=Hq{EL&T9w ze;4F-zCX68>eP^v>-Gb$)?PoJk>z>WA8osAO|<)HP}L!H0Wx~WVW1GtlTg=px~-Nc zcT;>*iuKi`i(+N3fUQZ_```e8sT=AW+?!h?Fi`J?```ht`2#a0rR(kD%G}(-Ky(0* zVndp6ceEFhlXhJD;fG6${r)+XQzV<|4#Y`!_IC`9r=tODh}YYS4LmN^`t-8n^>j~@ zj!#~8adZ(9mOd#73HQVO+*sl-)Ww0R$jQomXmlO72h>iDtoOw@&%DcL$Dts6%hHk( zB6jWzI-U0CwyhVqY?60IE#(De^w>B=b>JjWe=-NgE=H>=r`H{tK;7ZxCEgt$yaRmi z*Bajk@v<(kc$d@LK3yow@_JOao!W;U0fO3SKVt#q^UMBT6hHl*fngLb9~?$#g*Y-9 zK!?rUq0{KWu)YEjg^SI)AUdDyadBYh_flaqk--$?l)>S3IKj5roSF#bIEDq_A8r8c zaRX)2t-s}U#e+l`f}{{vB*kw&dA3!X|u`$bV>;KC@On@Eei3C>HYeMLF@ zwtHI~>O+DfA`Zqvs}MWDEk-YTV)BO)P#aJVK;oL4-CJ5;#Ay`VOGJcV{ZA8(-3%m9 zl~MwX#dqk)N^C{rt@2**q{cd9G?i$W>z$l$gCuU}c&=w3U{di!0ysXei&f+bypz9! 
z5rfH*GK%o|6p+M7Xs|i02nNc1I~Qwox%6vXjC39eBjEaG_Fcd?r|qmFN$=h6fF}p# z#G$*BI)+eiC?XUw{WP}1h=*srJ)TXQzE0tDNsOhbfwgJf6HYfmxl z61_w2?^#ABCbC()A2|`&Bbe)rUT>0Jo|K&C&;H`bLlOA!<&7>&z=Z z1;jnv)%u_hwbUtT4H;}9> zsIPY%YP13%54F}mC-*0F^c@~ooAA8PAes~{-n1XP1(I=OG9*<@%ry7&A#rOi=X$hX z_fcQNTo?!)Ur$+dtZdBdP1d+CQ9!d`wp{9bFaFm=RFziCJJ99GVga#8VsSfqFvjI_ z2?GU{TNMCvg(d9t>ds^ekHeXnp5i|>pAK)wSI=8Im_5D8RC+>8 zg4_P~&QfaX`}@a8{Cca`+Q3*+BCW2>PiGDHSHR%i=y^W_l;3<)II#d$H<9`IW=tmI z&6z27`_nm#)|BkTy?tJuo`5mC#02>zf0q-TTo|{kY%%-*GafX79|?0c0e#oTD9qfVWAd~+X@y|H=P#hK94tB_7@OzFjfSj zkD)nPAGHR>-*V#`VKRokQBMB$BoW4Y9EZ8qWtuQ}iIlkT{jM`!uHBaSH2qOAKUkz` zI|YJsZdzHvRl`2pwk*cZrqFXS`uyvaP z24?xx&T83768eoU*tIVYN!-8OV!Bu14`iuPdxZ1^id+!*rn&gs?2HVTuyWME^3qV+ zzUT-@{nB)XZnj8yNrSTR(7TtXI7~dz2d-(QJhYvtj82t}07MmB1kUc&US8u10YPa~ zFmprWL5+omGOABFzje<8k$}Ma$9sX-6Xd}ifBpEm5W8ETzUFpMXD(O)m>R|Exvx6ug0Ulx>8>ydz z<8B9&4ieN&FF5|qJ4n|;rYa9-92Gdu54O;lSExQV4uWXq(Y!{pSw1P)klX&)S zN22N}DW9L8cZXA576VW@qtp0;)Y1{K1V)=Wn>#UV9~K67Z8%#J zcP=g|CHtq}_i@(s)^Wul_(phP&Z%I{?4cwTBXrY+S(c@nghYgeSO}P z_3(Iu85jK<+=$bZNMW_{qPAf3oHx@Tw z{qcvo1)u0i@l|VST24bn#?Bq^ya3nj3wHepy5Ug3%nA#u{`ixw>%WV}wYAmb(r@c2 znTg;yC5R+>0lz;M;C+tIu4K*sI7!2jUeBBOO?khbnC$C)ZcKTjA)%@0YGNaz-Sh-q zU2S#wjO3DbBmWWG=rCWnSFb}$NKk~!ic5kfnfUpAE1kjNaO-oQTrj4z1Qs6F5EP$L zhg>Fk8y^FMo4L8RJbVC(U7;hBC353)2PmuQvhXXGE0=lqE-?oiu)Lru!kk(P`=#m_ zL`0LD^Hu6kB)!!VEKc2I25r|zy4^Bee!AqKK_lm0e0v;nUP4RN0901}E?^ynfD?6I zRAn*IDUeF@=U`*gRYf<<7oqtL>)%RZV*KziPhBVEu^+MoCGsy!>q3)Y!Mb#b*K%%V|9f42^H_7q;#NS$P;Wju+hb# zAtmx-Yjo@&{0|CBINnY%0Rn!iQ|Il?kjdPvFxrXWSfck=4@6%v=+gt(NC0_THUSS? z)KgT`-+@X(%$JuX59fUSg`9ucmtWjZYWn{m2{&&-PBowVx2M2=t?e^;|!H+5;p%3B8 z9NS{3SgaNMi_~tS4epP1KWkK7uJq&2;&%jqwO&`3XyhkZqY7L16TJ8_k{0t?_}S-m z!RB=G*!rKpqqz<-o00@vD5Pu?Df_o4fVmf$Ssk97j2L!h4gnIUXLI7;vHThuSZ~We zQricX){De#SX^{tt=SQL0t2XC$kolGK?T*Ba!VBt5M5%a0%~@sU^rNOZ>Mv<0X4Dy z7Sj|zMO$6LR9>XqzUsR+rHb`m^13J;nKPTye*Y4pl${@ zZ+pG3Pe8@37#BM`H%Ez;mv;#ihs}tuQ;h6Jm-n@)RC?YFp!aCr**l2O0`|=1fv}|5 zLYWv+t$EY0tgJ8YkJSbbab}ISL>F$;V8`}*%tn_j-r7Hh6Diny#`jO1UtYB+ZtrNk z&_`~z`36x?ya*x`#`uVu&Mn{65ABd(_i01HW84un#A91VWS0&041j%>!rbhzPQI>* zzCOP|Z!)oHl!CAHd8lOANr)Tm$Lgta89UcdRF>;WQkPb}=}bJOA{vcW^`K8fmrpkU z=7xrb4Y4p29JuD6V)ag%ID(PppjO95TzxA9kX?}R1&`b=l6&qZLBtpo6?akPvIO-L z7*J)o+S8EI_&;aBBwuaK=B`G+X*#*sh%|U@*-I(Oy<9)pJC!T4#okT3-nbFVp zs}1#>i6{*0ak00LliEaJ@EE`fOh!ZpPu;#fEjAm22WPIiygcMp*!NufMa#&dp9j#u%%z zW~;KWn7J!!r?MD7b0a`s)Qn*TMryPh9T}Pz2%G1rscgYx44+b_zB2SnJ8?L`e<7d1 zifBbRh+9;g0S5uiEtC|r1FKIk@HR+S={IN*9=nJ0>T?YuKiqWUUz|2av!R;3?=}G` z=Gy7q}Mr^9mMqZ9WKc*S@L`8&1 zT(+RDXR$N=;>c{&8kHQQX{t(ilG8m&*J`X{UJ12`^8)qZNU+>RynlI^(4}zrp-A@( z4sKOlEZ1D!W7c#&grW1*SyzkN_>>~*1{;TEK@c}pTVB4|yOCJ;GrQcfp3`(YA_FgcviV>p zzCYGZIPAB7onNK%(Q5DG38J&4>&ABE`MvSA!qblv9NnM|XlC_OiZ7;N>qDro}0Y2rA%K|nF{}fCL z;S4%zK{vz&&U8sbR&((v&`R^u(`tLnJ8DEGI7WU_IK>Oz)wdb8ufw$IZ3E>6C34q} zG*+urPNUGC=G;66c*F)V$DHFp%22!}E*9F1O3KvB&OGf^4b82MX1j}2mgQ!IIqVvH z`bn=BO~UBSej#?KTRmi|B&^BCfi}HO?==PH42EAMJ!^_#{yYl) zgnFCBrUJD~Y0;Z?$T)h->ePc5gYB??nm-MAta;gNHrp+htX~yEplHlTD`+Q*sd(?N z`_vjffVPb(-Zx7vPfrz<#|cU6)`LY&h6+@Sc`JS4_YRebVJYfgzee*iI^nXaL6fm& z6Q-{h&~=+g9ME4#Nrwn_Nl8Guds9&8NA_|+Zq2I3Jsoo=2;umrp;8BpU%DxxYRt_Dr zEClx6nR>fPOPptzi`y}@6f9qwqF`vyPDVN2=R2&ZtlIQ-%uwpUpOqQU>~lJg%C<={ z=|%!qEH@)fUxJEgR&9F7gG&hMn z5k`~UXmp6j#Ah$62ME@F-uQf)yyOZO!>j1FH|q~YHZ@u2N#MJZMNz0TTa^lljqG^V z{;ocJd|bHd>gupL&yS~yWg3d~dcQnC4t%mcA-Yji%ry!h+ezlp?_D2>W4&5|CSC97 zeBpVYb~#JucuuRo+mrpQ_6DfiU0+{@fdTe>kJsZFcDZ9ySq`MU&3Y3$9dMU%-RP`@ z5T?6D%eB|Z$PDfir^98tI`VK)5Yhm!IgyCr>Ga@s5w>&tLiUl|!( zTkm=n4rT3X*z($J*6ZaphQ&BPyRViiEG*4BVG-Rsihjjcd|mfF@vdsF-8pD^i-=%N 
z{<|Y%r9QiN8zv#Zb9cb%{7b3*gMho|m}7bt$+THak(g9NkX6t94qvMl_sVg1X-b7= z#o2Y2CxAHDa2+XR)#M|k~!Uz0|X~d65iVM^k96b43rNAyWaq+cs9AsE%DEP2t z+8XAw6yx85W_A{gO8K}{UMAPnx`&{ZjO5Io=~20oqaB@yio&q*6TE~HA)~90KT)Bs zt(?D{RW461%`eJNQg;5){j4cTN=M$>y7@lo(rG=JzPr0Aj#SL1|26>pk=kDl(|kUb zd}wd#(8aGDIT#|any8E63}0am7jrmS0UJgTo_latR5~54@PO6RWX1vu%d&`^f!?8B zx7V3~T#`^38!FG2z7_8xbl3Ge^d5P5*VpHpe03khYS7Q`8tzzW85Q7!^=I6b^mzu=kyLAk;qsi|H#V)J+DE?Qsfj^?3Mfu7Xkjk= z!TYiwXBrz+`5V2npeUcj5M4OklgDkAsi%MV{$`7gUM8;@WA5-t>$$_c5yQ#2;Ps$qmijG-3Ido z;0%dQtKm<0oh?VWabGlXG?mr%keK#Q^n1P`e6kRl?Qh@e$Y{nW0s2U+sfn?^72ain z+9eX#S^)xBO*7*^dT{4^QG5WHd^;LC>h<zEU~JGFl?V1h#w({1d*zCbV#X*47G2LqM#sC zem?qM&NeB?urjR2yd%S)l-?2+lB4{lKrBeMD3Tv)5pY*XaiMmfIA%T!R0))T4OuRz zS#U)3)5jAFG4aN@{Vk3WRX7NY~oQ&eu%e^9ouA$?xlf zg&c2ThUX7%xEhdvNGy>!k>!(n-y4u=G=I?0K)3VMzu|G+Mw9_6)5s2WSO_T3q%oq5 ztyLCr?+*2?PbT|%|3#-9&vcEjvnjZKAavNl!!%^_T`C_Xp?o99ex)rU8L|C zjlMy^eVVyk2eG{0hGYo(U2d)fORCH*L%K|t!6$ElyO>N}@Ok5OBmiIARcFUU zqLZp$E~Eu*jQ00r&Q%%T{<{mD!z^RUG5w|bqpXaIh9sdZRu0=QW3`N8KVK|gnov>0 zbj>jeLPbYm=Nb7O&$ZYS9V-?pK+0627 zIHu1Nk6^NH$Vt!>-!;fuCYMWRY5RB4 z?DO@OyT~;glMV_owjBfX5egT)&;9DL1Ff3xb&A&hRUZ)WDzoR4okWG^1?JVB7$$Dy zMfsjbYg?F1=1R7R1`?xRA3gBZk6EKJ$&f7`OO@E}qvpflGOKp%+I@jZdM7SG$3oy= z!l_Ii?D1u@8=!B;Z-F-}Oky(_gwV-8(k zAM0*+t7~hv>l|o*^C~%w$1(iGWt=V?`cxIWU ztn<@A!X9e40wS-Iq`d37!Zqf(#q9B8GCYp126KCg--hSsV$5qh13a|*d&36Tbrblk zSe0iX*0Y4w{(5$7aTk(W;YAsH7%j29K#zP^yOd^qt?LMFuaI#pB^WI(ZdN+4L$irH z?nzCcd+LangNVjz^Ugr~FYKbPbaq|$=lPQxQqr(Osg3q0;__JVpN9gunZZ8MH4UP+ zEKNflU#_-j`On?{qi<+*gPXp|=s&4Eq4Erdg$2t+Cqu)^tSyC-OpVA4DR1f-3_IPQ zHrP4chsd))0>7`UghKt~Yb`#$X}j!?46;7hHytH8iUSML|0_K`+cn1R<+HlSQzLma zZf0rynbovJU{CVnK={2aJUp~_F`{WA!|O{rY|g`*){!IEL^kz40hPA1vb!^H=&z-< zMIYbHkK6v8zO3f6%7D4Tfo5Mxywq_M;uwvh~;mR=ot3Wz;?OapsT|JqY5 z^IN7vL*E~3++Ry&Rk-IV!K3J@3M0Zbq%YrT2@;^pFp`&(PlEYO87M z+eVv5#ty?L0rI%kZ5NL}kq7tu41n&kv_|UzJBAX<5w(&DZUwPeJhL9G3q(@VBAE;U zE4Z$buBL;I@p78^vDmjmc;B4nX2>*{O6@M7I|f$LM*CE$^E?_oqj_axpg;8> z?BU`052%~4GAh}X-qmogudgf|eouXej^`Jbn-GOI*LAjDC{y_j1N&R+S+&`oDE>ts zTFbJ142I11`z7%7pLIB5tgU;xSn(5V3PM4r>%0L!HwYnaXS4D|?>L&um>M2m?l74O ze84uIO#jDt5Ol*;jf2fazDuB(Ug>Ap0FV3X`h2Y)6ooTz|C_##jul6+@>0hRellOD z!=1$&p;y6Pifq7+t-%Ss#?gJN2edfxa0jDiT?-HH$2s8 zU*g6LV433Tus|h+(&6x8pUy&RTgvf+sL`m*u146INao$ESSY)J?g>|~Gq>}6 z(X`1a3>p86akQh?N^zQeD$i|?@v#=H3H5N>{7MX~xmzuasb|jErLTDKhubBFkWGA> zfYhsB`f<#TF+m?>UGn?EG?W6P5ECPHLJ%|B3QeyPonNNgLO}5g%DTE#0A~2TamTJ2 z0jLNaDt8w8`HW9VwNCIeaE}&(riu zNlBb&I0AeM;=5=IZX{xWAXx;?OV8&={!nIka#)b?XwyX68T5}#?T**$6edgG*j?H> z;O(2mp>Hz7oC680UtTcib7C|BY=y&T{dsaFRDyL}T(-;RCL&qw>)ryu`R9kL-xxon z0}{sEICdf_nEgmuxsePFZIU^d=$IT94xe<^Y6A{26wou76cns8c(Ch{MI$86R-6Lb zV8A?+jsY<@!n|$>EWh4@Zjf*foYIXDX(NA6l}{3j=R;oJEhFRC?&Df?ViH^D+vaA6 z(kv;7PKKgLWb;*FTH<&*2Qyt#M708*_+Y#6IIZ?a4+QFmkEt+O(q!5j7>z8Dv?(x& zL$BTeL%XumQ<_HiUflW zG(Ku?H=A3%STjF;lSaqYk+U6IZ*lph!Jd=fzi0m~DqIhj6CRC`?WJP4qx(+K9i1Kn zs8$Zb3N%HpxK2QPWP__;N+4whg;#UByRQ{W|GiP9Dw;rzYGK+JDdJ zYNr2g6d;U$QCHY;vAYQDQIYN}u8X!dQD(i}v_LVtVHl;v`+m!5;%9G%@agK zU)>NHF0S)vz(@Z-!rm$>&S2r%gy7J)Yw$pD_uwwUB?O1Y8h3{Pjk{}bch}(V?(Q1g z1Jmbx-#@eFX6B|Zd#!q{s&?(YA3K)H45$@Si@z0{Th(}Bd1PDp2&<$4RuxrD3@jhp z|CsJk@Acl56zuxbxN9`ND|1Bz+|kpQLi+qXF7B5chPkRxzM9=0O_K6_Ytat>ygX_@ ze>_9?!^n8M7JF#u0rWW1SvLKb*ut69jGc?grsr4iXVBLBYI(J>x%gZ9FDKw-=Mh;& zy|49s#^a_}^?2Xj#aWI&$_u5A=vfQF?M`u6~8^=Daal`_}|>rUdOTko9{#B3dE{ta9v{r zAyKZ^nhcHkxu0aYac`T20k)|f7*!1FT_K3(ls-7gfW98!M7qaPmYdMOJHJi{H84i- zRRH`K%}2?NSK*=@N{|BUTi<7W&&$@M?N2&4jm;^M zB+tht2Ko&axp()m`kmc1XW2G@^U5!N1tPHmh z?$K{Ej=k4SZKH_dNea@ofIwux`ZC zSkqof*##*1V@I5aYf3J1oNwkV{1y`eaR5V 
zfI%6d1%HhXt-}TEO$dTB5NI3dQlA?ws(M8T7Q1QS>uAjSYf&|40UmT;^-G0pRyp}K z6o^6xfSk7COk&Vy73H9Rrfxa;H2g$P;|cks%x8SNJ+Z=PVBk)p^AoeWrXv@!A5_iR z=Se{d!FH3|!v{1Hl_;9!o~gO)&p9`yo#m4YkK^z+sL?r?1<8Lmf=yUGByDX zp}?UK6@w0KNI_)aQ96-AZGS8ObD9n>*%x6UAr2{V3Zxe}6Zr}ac%jgHOG3H=|!%i4rHW?fHXEiSWo1#F&4v&YA} z*7LNy316*PT=m=bmA<>tJlaSFdi9B@xuvkI9z=iU8^|(X#9F=K%7k9ufmCJ-_v2Uo ze%t@DoFo0x`btFBBEqfgJ1md z^w~FZPNl}2(vyt^%S#$~urC-s{_Uo)A{$P2I%8na6OmW?)obgOOwL{SUx=Sb7zYQH z6J+lHCup3eZq0yp-1nGhBssuq2ll$|+#65hSgy05ACBu`tr1XDME$nHs)ZTSEHXE{ z=z4k1w{L%!I+MD9cojUW>W$DN1?FqzOI5JHd%8~C-`}C3v5y)UppuUC$Ia^%{B}MV zI+C$lW{ympFev&986@rlGwItyvC7LzgpK7e{owZJpUs=&ZNl$&oFQJrMm#(V6L|@5 zVgwG0+F7)Q<*u@}I%Mg_eZz1|SF|G8ywSz=b;pPCv{{tGWltLC3l;@-fPwg zGXCzeAS`kX4EBc}@GYcL?J)CEY^Y@E=`HK@Eu9RFe>*QE2;2xY!_Xy0v^RoZa%RJn zwvg(TQ87Hs@!*MH2|Ib@Rhq=c#wHHDKX@WG^|lO!l{E4r_s0%;3z#Gb!S1RbxJ$i@ zg`)uXIJ-$U?v*|BXYwl3bK^M~Mmg-~A2o^!?xTTQ8hgw67NanBYS8gd{gf~f28sH3 z!(!M-r4EKu=hx@kFO|}{XQn;IuIHkNHwQM9Z=xau(|WwBxATd4gf6nOCYH~KvN))1 zpWMBEQPUiz`!$WlExpk+Ay}9jk`s30y9ipS$$$V+b}>9#-cH(&Aed|0wN=6vB6fdO zw4wq?-QrjUiL9I1tPSy2nXIR!g#QUNmi`=P{;tfjYmFegQBGx3y)griB=qCRWJ~01 z^c$p2U)t_ZYj}JA4{B)(8rJ!a{{T5t^u4c(h5yl`0iknOx=T@}fDNGGPr=gH-N>I# zeV+8eRQcEUFpA?`(BMUROzjVaad&v6rrF+W9_AVRjL*ZYy>dkm)>*a5dQoi;i%>rxmSQ(RfB`1JO&!n^NLgy+?xVDiVsxC< z6YpOV3q4+ym^>Fioti%BpUU4v!!X`oa1c0c2Px*;djGdI6h-7XT9OtfxU+1g&CJxLWd?IcBXULPvJP5xS&&3zQ~AgP+D{ zuMY|^Wo!LYEb|AGt%I&`h;RNupY8LB|isj_v?B?_j5ed_JqseP+E!;CsOqt zLdZjapV)-mQ)}9vK)g7Ag4MjZ^kC7Dl=)D(fILyty51##vcbqv!W2BcmefDfgcVug zu5F`bY>H?h<&1w$Xk)3}JGGsnJUj$hu;e02yODQ2oQ`bP+E9tnCJkz!;J65gz*fAj zQgw!h=}_AJLo&Q|s$JOi5`n!O$R3l~E-LdVO1>DCU_Og8*90WYUOX;E96A~-+jyv& z&ahOzYTI@{g7qqx5SJdU_59*LB>v|a8kTH`Ufs<0yPQjUt5Si(AN6eyI$ z3>2(iP$F@jbw)nj=Z%2Y2R@l-W2fafi%{fg7kR%tLZeQqsbeuKga46}H;irb~qrMu3FpctJm_uw|2 z+on1usz^a>=Z;L1LX9-4PA44XsLZXTP~eO+C1XUvA~&65^^Hmj;z7|4M^NOTepOX$ zi2%_+J(m5VAd};O7Z;odw%#aPwAkB zosb=kD*YN2!4r4a{8A=ITC4wERMuI+Ub6LH3z)+{STc3N3l)8uEd4*y#v(^IL|BaM zTPOVCr+CQA5A_=tl7Wr)6>CzmrMdukS8@15Oab=#;YqcjkCvKO6l~zfH0NqW-9G(r z(T?Ay9mh(ITtQ*9|I9+Dirf?%c1mhjwwC_9b1_CwjswVCBX8ohnJB+E2a>LC;s5EA_l!d;{s z{$TJG5yNsgI4CjE7%Gq~h6=J!Dvad?qg zR#hBzvf3qOZc`xHLTR`b+j*$Fx`?S~9Z} z*&@@};Zu!R21E!{@GMbn(lo`=xjbqZWn?PmnXBoBR5#ab+il20mN6;|9Zk*U6T0de z*0yHq%}gnQ73KM2E%CGVP~8Nls7bTKl{^|OSi-?y>ABtnIk{EzV)>u|+*GDmP!fV? 
zI;VApWmQ>VIk_RdkA2;db5OYfek zH|P5TQUFT*RIy*Pbdzf17};Jg#e&+jxFT{D$A1@qXdEEy%t1S0!H+_cc}?0Y&>zrYd~svZR5dI}L6 z<>SfQY)EV2zxu`XQh)NDwdq`3Bj$jhLO$i3B=3C_rY z3OMuGVV&G}kjwDHev2g~60!V$+!Kj^;QOU^ z4euv?3EZ+z6q1YVKmU|U>;o@kwaJJ0=SRge4(v6qz2wFW98n7jvi5kAFXGE)=pH6a zXsYQpc(Om}A)y$eskONsx|A-lbUT8yfJIca%-hzuBkk>kIcz4JfO6B>1`Su8i z*xt-L6!5ZjGBgiFFB(ZOOQF!mD>YDQXrP(Csxp0|taR;Tje{JCs;Oe= z+~QE!s{4rkT$E(&Fn~Y+61mgvyc*$6z3MJzps50EZ?}g!yet(r8**)!zlI`gK87M^ zkiE^aH6$c-UuSx0KjgfVm_1mw*x4=w_bfa2WysV#*LeBBF2TLcyVB?z0bJ$vlq7PWX%7$*0c5sYE&^Ke7`76QusUQ#|N1={K{Cxp|&U zNaVh|rUvj5x$0Ok1g3OqdLREg3ceANgtDEgVOUfOO15(@^&gM6Ge5@z<oQFvQo#kR9-g?@KD#C-PVCCCyER4N=x7Zn1~*jGxTEk zn{fwq%Uqp(j^qrI?M$~oo#-0~lwglbFT;l_K(Oo|mc`>T#ftxFjAmKb7E1ttCV`I$ z{_n-6{$T}%jpcJVEd@O5&|im&vJWbE&kz55eK0C~fEebbqe(17k>pF6q~ZB}aE}Bg zNeazYXN06J%lkCBz?iA1ScgD8uwX^xc=&_{bkJ_F3Ume12|qQIy7A`t6?nu*ysi$5 zBNjA}A1Z?3(Z4*t+q~NwUYW#3^s?(O3D4%JjX5Ys0ELaNQ%W{blH*ob#Jh3LTV3Sw z8p1|+a+wJ-cP#75Ri){Pzc{>DGhv_oq^OvMdf(CNp3_M`y=+QgZZ9$ObQI?5Y&gQr z9p_AA*0zVb7+8J@iu%95gjs~V5ImNq(H`N^tYPn7AG&abEm4G8C9PumzU!JUyg}B| zbo6Qc2o$%9jg{3eI})?NIW2e|#I_Zf;vJ^6W9x0Gn(M6M!F*?PdA1sgU{OUMnPW2s z2lAa!${2!-&GyX*TwMke&bwad2mquGk1{fe6^fRuPE{yfMTc4oX@n}Ho}8i(Q?_$4 z19+`#PQX~esak4zDEUiPQM;6gMFj(NhjG(n(pB5$F zKxJwi3!H9MUBms=Xg zyASdYSyp~xf6@Wpjcg^;E40rVKS=71aGp*yPJ` zu^4(n|Hw4;GXCg6fs>VsQWGQkq&@QMSS%iuic?(gxZ5cmId0kG8e1a*`|k3&W{3OrT`4@gBqTB*^x6MG5}Jo9mY9}d!%zK zKeI2Lx5j4Cc|1cn`!~hKRJy!a1;vR3Z$8)qvj?eln!Afjn3W8xc z)qm_a8#6oy%eGL}?GDIwQ+V*Q;-zX5!JVPP5#^@|B(29vsHdy2N!`m3eUdp;O_1gM zBo|a7`3b|K(oWdntE+o0$BgrXF+`cXAAU~_kF0vYi>q(1=h(j8Qyd#9P`?`tM2c6b z<+UTska=NQ;2)jp_cSJbAdzJYYy$oFuq~jR9Y#HPy~|ISR?IkyK6`C#jeZsho7T`e z`(&=TT#c%FGm)=dr(glFBYgPzpv196*4-oz$&I7s+o<%A!Bf1dSS`}%ut>B^^8y_Sig!FQdGJJ z0HDF@o;2=&A={F9HHT_k-4ParUedVZVD#FN;o(q7VV(>;WojZ-twMWSUj7>iGVu&r zbKePS6&&h4V%YYzyvGc72pnE1FWCdbgA_S7`kYiZd75@{d^{^%oc1Ggu+1XiY=9|B z=76OCicz&xhb9UqLy~(WJYz#!bbb-cgLBAUr;4i_cekYFD(w={?(dQ*Jg4&Ga71B` zK${$X6>@s9hJOsY*GPeLOC;eRt&%A_Y2^AJk3oy)e~vfT?7B8)XMc%X9y^2jr8Hzn z6it3d``rJe7l^lJhw)~Bqo~l;sRv=8&!zWNoK>pHx-A7_g}MC2lg81EFT^ZmhG!Fu zP`A>Qr5XUnDb|IzSRO|wG0MC-GnHulwyM=K#%`zfwUToekIaon)nN_tL6#$K8#v5 zoF#m`Y@AjRji<4{v`4mUm-8c`}L)=RPnAp24SPcy5Jl=dsvAuOPj ziUu`XQd^rs5=4-U?~kxFBLW5qIX;Maa1(#Jx3VyzGgc=P>H zboc!nuTM2f(mAJDLlmsbJznCE!KrZ(N3uNDBO=czDHTc4p{A_2qI+QPN;5nZ+uMj7 zRgN*7N9+^B64ws=)a4#JF<+{N;9JB!08)-Vx3aYmcn<3-W#%6}vG13Z9=h^JO<;I% zr0m~egX+v9Q>=}!(dgXQ8qXnwF-ldt8S_-ZD5dULmfVqriW^c+217-UWpqLG}J@Dp>M zq?xT>+`f*EenQ2!!~@W0+1e?%`kMRbnLCjZi{&8GTGn`(NACGIRKcnd>gB*Xjw^a|f4ep9*iYS^nagGei%Nm`?c0co3_->+Wnuq2oa6bW6_qt@ z3aek`PKnDKY_){4|F~-lugiR8(Zw1gI!p9$zQAz)+GOFV&&l0o=1}^Sqh*S12E>4y zgi567$)gyF5BtW=j!9+u7b|98W^2r_J5<@Zr<0l$UJ+HH4!PXYL3w<{UB$#<*i;Ng z#UoW22eD`;rLCaoIL_=49SILQ4*(x=U0-bvclf2t9SqK<`8r!JH_M8izKbz_=O3m= zSy>>hM!k>71yFs~=4-?yUI4Gml zQUF^?!kDN8M!&zFmDeZ`H_C)?B&S#sqe+o~6RjuQ;Ws4-VbM`wjv9KqFZ%;@IU7Eq z;EP{xuk7~m&r;?}E}I|V@5EW3cuGF}h#E&$qw*P2az^+yr~fqLup3lem6?Ij#=n+D zZ&}rFx7hdE{$?{zvq&jBr(X~n!@0EY$5HPN3pEw z3lS?8vN-ckorI7D+y!8k?BTz?KbdsMUPoS#qoAOuxT-h%_v>(KC))g$aYF3^oag#z z5@OhWqH2$bV03(1C&0n!OJ{=fY0Y(4LVqqz9F{ZBI;HwE+a(<0vX5d}#BF@TYKHP#w=$W@^}_ zO1S}>%JlKN=&~I-aU+?~c$9lcOW_+~FwJ7dsX`23Pr!gl+#4i52{9%PP;kqqFq1<3 zDWNrfd^b1AaapN2C-#GUd8Fw^GNk~%)L1AE12DM4RFs`(e$INM21QPhWq2gqX!eTc zFDB1DTS)^B`asbGT9wp2mTiNv;Ma_0hS9}1AiiW@H!a>!UM`ji?TVJx*4?{iI?+hu zfQXa0x|YaSWn;^xzrC$3VF@Pew~$-hn|ZAe5yqj%bWi$4=~X%^RF+K$3E?R_+_a-_ z=kHDP>;Ek;MP`(Oc26e~Nq2O+ZRWOrjC#KgL|v}7`5$WQtTX@-50)nSr0-gafW z`O)_Bpk37jn9j;r`InKa@AMt9qX^%`+zqy`J~sjyDEm23PRXb|uAxDn-_cb|I_L46 zaUqs#qVoP|Hd>cnx=f#Lwdn8==ChCU)n3fhZ=sO(^&JQ|GDFC`H|RP#dqBL=#YMpN 
zQXw?PpuqxDAQ-A?Qe=DlWDWt1ER)}Cw8p5fUtdM@UwCI8$W?_C05WE~CHl&S6<=Of z{&>E!7Zqc$%<)hI;$tl`b83uW3uMmZ_1&Ew3K@*~EbO>HjQDn{IL#m=(??aPPpfhb znj}~}hHIIwsjb98x-*sZQ~_ek8nT$?VX^Da4G`b~dX#&!jPQx$F|aJ9jEl^D>HTA$ zVxfS^u;SvHb%q0SPNY5gCvbdx#;Kg%BVPTAK7g-G1LC!%oq52%HkYB8&n#cR60MiU z+0Rsy(1QrW+7H~Nog)c?s6&z@>P2s8;)W$V(<)7kYh^*eo{VSJ`!Xry?(tw)4r{O( z89M<73(%7iKnMdvzD*I_(^-a(gyc^VXhmlQP*exb1etXYXL3>|(Q#zMM(pP;6jloS zK));;rP7MFZ*i?o$zEd_G_fA8I;3^8jqWO9+_p9ogdxwdT5&y3Sgy_2Q~MGW`h|b> zB)rdPy@zH^G$(0s;IF=X4i9Fy0X-}m^iCa&*w z5p6d14vE(Z+49Iavi!L+*3_g;q2k+C@;k~E5Al6n+Zr4iz(*$@^z8Geq3PJ8NRv?< z3Ci$~*WMU=DN6X5z^gU>(n0$JAgVi%}y|{u9#pmR&t)Z^+%rW|6O-HV} zdte|*f^!Jl4D47CG>;%+hzceZ-A34HjgUozR1KT^7sL@1=+DUAF`uF;QvaAZUS4Me z!UL+H+>EO@_ak8kM@1lW9?e0>J@{9e8XN13@_fchxceDOpqQofUy7{zNVXu zUda}WqYj@cqY`Ff?u&#yiwUzGC!4bmrEOT?B8m(@*03i0Tu@%`cEEMJZEuq-30*j7 z>(`z1JEoG%qq7N-oKQpF&qc#|9=0)pt20i*D|3<=M7gq}`nR0_DD-w!zKKurf9-Yw zqT^6aU~h`kL2(7CBXYu;;1Tfu6xj`uR z=uV-h+mm55ewU{ft~ejcPH4i!t_7&oARqbiFsDhtcSf5>>SfbP*Tih--F*wnYhNL-8y-XYpVnoK8RT|GLDlS1F zcEM?>ar8^88W<3dRc3&ufs2lML*J%EouAZ2^8ZLP@Lyn3|Em4?5pF)6iyQ`naGjZ+ z9^Ou$sl>ti4TSe!+rzlOr(F?6?ByA!41?Y-Qyzsrwyun>xR{D%T!6{dV4=cim(F1* z!duEIjn6{*t{9GDpwW`v(i7QMtUlCrXj@W(-PdXnAnlg{cv_l)%PsVsqr%I3Mb z*&Af!4u@q>=OZSf^MYPQc}3a7;1DVyJGE2`G0#}>)Zd}7Uu?E+6+>z|Dk@vksaa8_ zS2`^wf1{YVt7^!FYgxEq?Mts`pJ#6o$tl<{4bIi#BY6x?yJu-JxFH80Kx5xdUNbCe z9P}6*KOff5!e1B!c!Z}8oIm;h9!O`_3xD*)L;{{SyFJ}FI@$I}k92j>Zatq2fj{auTwC1+xPDyB{&$SBVU=uAf&`)z#z9PtllH25IA~_7$3ailS20 zbesO`2kbU9g)+d$XVE7?guS{zxtI3-x#mQHKvm5hk`(}LmXqK4@2NJQ12|dJ`%8fC zni-AOSgozAPOPNoO0V`OnQN;2VhTly*lwq^`|+F;Pkft8gBp&~MCZ%<~5h>7US4$FGqr20IS$^Rew zoZV8@#+$pZ)eTk5+rY3&zP|TPu!(exVjA~-)RMtakn=nz zq_pG?0x(;Wjv|6{WPI_xU(nL(yoO&4%(pg`9f3%u05VQD@28z=-`fn7!orgN){Sho z{2{@IWytsxhLnGe6&m=s+?Q$i=f3b+?GHnOL>BLF&&RWX3rH}t3xeA8c8LD2u>8`4 zko9M><5x$9-5)}Yj-%gxWVL*dKxVWrJjGSkQH+E(7T&<)mFx1rGN zWp@XDGtTJ+pS$Imqu*Y(L(!z{P&xoN*BgUo#!dfuo<&2^{uv%as# zzcv>-Jd=*SFADA!2Z2j^JScki_o!!!RnSb~Nl&`q(erRz2HUQShg1T5RA}rUapIY? 
z@G+34`1f6L^qEgK6#(tWNdfD{O7SRyEo)>HRp49NHpG;mD@=|Oa4M09;g&RSKV)XUy}hN1kT$_(ly>Ffan2ONJGL17YO~6& z8V;jakKhvAfrV!T)UOKCC)&$t*_KcdgJ(aDrC)%psBL7o8Vt=RcOwES3c@tQ!AKXn z6@=vC#{@)*?djz~HYx8nUst)X{Ne+B6@5P!=o=4SL~^k87~tY2q=)=Q;LoV_=siHV zbzanXwPuk*WoTPBr=D@IMPl4s3b?G=rj5sE$Q-d$s=j8al_F`nv>F$2aH!QgX(3{X zB%R|*4>&!wsu&n#=R8LV`#VJ^P$rBvce4DH_~0(MWI=Rpv_kmwW(L)T!O3%mM-)O5 z|31iG&`RNNJ2g$SX(`_v2L_k1voq{s+K}>g&--{}=X5nz*6j%PU>PS2;ZFdZ;keH| zjDC#d75o}0I2sSlS5C|AF&&&MRx+~Mj3nfE;fso9reo}Q{PTJ>a&qq(a1L-bGS-u` zTCEdE5p6yB?&o=S(X7(*vk9c5XLIr8vmd0R9MifjK&Rbz1LC|poXALtOM;ZlH#ORI zC1oeS`1w8cwMF1-RsIP6u?gAe9ZhD05DJk$PZek1ctRj5*uIbLTUp2?db|3l%e%F&` zV>gbskbUwrq*Wd(As#r{7CjuN8u@aUHrMrhUoHG{obPdagwDXc+3fcLMkQ|JhuM00 z?dk(bpuueWygvL}s`iBp?|nd;tmke7Lhk;!UeKClXaDoiSe+N&@~;oQmg=ZLHbE28 zcTm>(*JkC@4{zM3L7qP7Y{#qT_-2+U(E#tu$Irw zA5ZAx>)wV3qlZPbiOP^=+t#+-dL`sl-f6NCeqEVj<+?wlkyo9ab-P@@ezC4LqTJv5 z_Dub$U0e6{IzSxaWpKahL2$U6{>^RYdpp516iv+IdS`*Qmgm2O)YUz%0F{G)v@4I| zd8uLW8|=;<=kDo)_#p{ky=AaRm)Va4M=I|P=2a%@$6@GEB2ht{J$Ts0m_;`Rd{LZ( zq7MCGyJv_$*b){yC4iD|@fWT2N1GchwG|Z=B_$S@>#a{G8=eGsaaC0h$_AN>y(kUv{$vOu|mwhz20^nOyEwtl;uy{US9XosN z6)U3utUzyQ9*D{~t}ekCiwlBgq=RK$t9>*#tyg9#Mw4&KH_p%8SeBSIP{X_k4Me=1|5f%qW8#nYGy+L&Zg#R~t-yYq zlNaF%;Z!BV(sPgXWyFy@y~vdjs^Zmo5&_uFHFtMdaXCm zF<+`opRVf7CrN}{0l#Oze9!iVR9uILpnpudb!DU_ZLH2l#+(&rds2s{JmIj zJbPHPHTg@|xF^m2Cf8*6{?`6iA9F9dCvsAwd{aV>?Gc+b^Mk`^w;#egc06BpK!p6% zeq8_lc9pgeeYD&8 zSnnUtl$hY*48;(l|Q&29-k zQx_>yaB>Q-zG5B%qffjzPB})zsYKBan*&YF77LcX2w}qm)=i=H3V8OAN)eLY*+hJW z_&gV&c(tswWSS76;_*vIdRXhfQT+eD6gU?G#EYMw=yfuYvjKxr3j zYIvwx{=W=V}zkblr-(C+NwzcG+2^NjzIyz-;o}HCsW*M58V($ z6)mAXFNrLz?u{hgRL=TbcokK9BOW2V)`QzC(T{^Qr7}rcueoMYZxm~9WDe;DXHNB> zg7OD4To9U@B!U)VE+zPr&iny=YA39!ZSB^;J>(hN0OSsxi=B5z2q9}2_%r=%z?eGj zZ6GCS-uUHFGclo!;PY^%%&qlX8AP25TA}lMrb}OQ7p&L#gn`0Lyw+D#TiC-|Vy21cQLCl+8P{ z<>3Jq-7A!b#N2u_B@i<{?!L5$uhiz&)^_B*b^q}(=EIUFV)K6LU8eyJb2=?n?DTK5@txg6~g!B_ga|^y`IQPn8iF4a1Jr!kLnNNKkt4 z=Nuf+SXs}a;@gs5AiL`18m`NIieu=uDx(}p^ zVqL-nfi;1i9KLLiD_isE{gZ+^Rw8l6TW=W zOfO+~tl0Ak>Y;r;Ex0)P z&b!ya#wXsgg!iqOu)92O3uVuVA0Pk{fYDNfhv3qM6hivBb**r*h#yS8p&in$jn!7C zkC`DDo0Chxv*r^v;tI)hnyFs*lBTX(4*YCTDeX1GSgA@}Cr!e@&%_V=u^-*{_a~?- zGR*bM;Fl?~_6AljSiVLA?m^A4;jR%brFr;Y)HI$i}ia@V~6Q@dG&2m{;i^6DJOq4;NK4o!fEP{Es#FY^(dn6-q) zDk=^@sY9)NuX*1&Vk<{tkHb2Hjz?HiwSXvM0}AMne| zPU`Z?I%Glt*0iAZjZ4f^8X*$)kaK6^DhbkQ&K{Drpi8?i2cv%2L!sft>;W9^tC48; z3@X|~)tYMeA(2QdS&qqQOY3q7NsbmSP%x+Wsq16v!^32u%tp7_&i8L;K~TYri4;Ww zA79T`Cibv0%HJG;WemXLriLctqOZUDGpCt!H@E8DST>~cj3c@{dGXuJ+KS5Zu-X%H zBIPeJnBuX8MnP~}j1Edwc4e0Ps^Xe(JNEvChdBYBJ z6ytmn>c;5l8%Op8;#ooxAYc>hBu_|Nw|F_?Q&@tWz;@s3TL8b#g_tpPqaf+lU)l<4ktL(;P+(MBPu6i&R$>X*d z(v@T9#M@$3>(R}K0Jp4;nNPG_4!wwdibMgZ2yhWXIy}>BrkEff@kUhNr1PR+1nY7(r z3-S54k!zU0VGbrRLq>gIBT-HpUZ}s?&C!0X9v+kw6*W;$kbE1j!)||iI?*`N1|&xL z`WYnI21X+c*N52ONaWGT%yi}{RfvOev)Ww$c0TrYxOM0)&y{j8QQh`mn-J=xrD{G8 zQ^nl58of)9$Ht}xnG9`tefW$gO6cLvY0X5)(LC%*u(K`TDE{#_9iCEMu8;<~r{1%V zE0Z%Av%*@W8jXR!#(X0inI3mB^^DE%_|gUJjBBN05IQ1%S%_4^7VeE0SMH;Coq@`Y zyCLUV>1QN3G};#lLR=pIL@Yy)ilNvpwHvsLtQaNX&a429PG+zdn!xf4>s17v@RDJq zjpfe|APbQt32=OzYKO|2ow%JUYJi-|el%Ka6-^TWCev_S6v?3uikt>?s{{RO?!ec& zrmXe2RL;!E zAM@zc^~5P8?hN{+rI)=K?J_zX#aSS~ZGe6tUF1q}2&2>&=2IAjsyL**j5scKFm`$g zXM^;ZD~h*g5qnS+O_@Aq^NrCvmxO%5+)Q7aox7~T`Fxq}db59-_q#KQQeP_SB4BX) zx-O-QvuwxbWku7(AKYtELO&FVBi^l1CGi%ecTGZlilKdAxoFxc1qE|2C#W1w1o5pq zaj@ReM{6opZ@2`rV?Z?g`&=x(%B4#SYk^e`fx1T*!{_gKok`S6oN3Y*;UayY^JvZ# zgFDPvE5+qF_@wO<9u_Q_iq&i`&+GXa9(yKL6K0nlD%!aTUDl7N1-m~@Rtp96YF8ZS zs2CiC4ho~!CPT@LI=^m6q5&wYMDHqzGK%S`(J4tuC-;Lq+}uiO+)UrN@irfQ8vd3R zzdOD4huPRv^5z~i>Dqc6M;P=^67vqwOz!ni%oIY@tfC!%yNWR&gEzY0%SrJUrNJMk 
z*MFkXG6(mr`)O%mr`BrUbIN6(^bI=g|k!A5`+iI(vK`R zVL4G737B90q(+v{<*6`;!ES)OHh(awJd2BZ3rbrLE}FMi+icbX89rwUy|#rV(4vy^ zHxzdMdwafJR;fdn_TFDsVMJje;l~j)E5v{ByHoMB`r5P#nfySa(HF7Bioc^hvPzbB zk8ihxAKx-to$U9~n*K^mI9RA)CL0DGUCrb}ddoiyKSqh|enc8yh9!KtZr_OFS@%a& zf%*5>JR5o{Ve74kSKC2?h_7ymkEE6Ywzl+|#?_$&3!SwB^Eyv~EH|v1WsS&l2D^;5 zzvy(p0)Y}+m;int*w^`mw~w^c#Xw}XZ(s@ zP2FUs9Al5u0~2?xIEQ<274&v|)FE-M;6pKExTR+mpP>kD8 zUW%y6q(m}-JnANLd{6Xb6`cJ>&__-Fo%)=Y?oZh;2XU14TSpQyw}!EkEdJJi^YjJ- z^l=olOj8`*1fIEx>=^8<>w7!R2${7}G;5Gea#G?>^xtMHDd)SiF*{O$ruSh?M$%dW zBt2_-B(R*gm~W{J0Bk5ruZMUV<=|5M<#lgEfsrtc+rdDlt2~3!-aB{gC^|;Sq4Yed zPq*V>P|b|%@P&DkzM8s9fUm{Rf2GgUb%G!#CH@Q7%Cy216xJ__0l2*IHJDoCf8WI% z64=uJ3P%Pc!R|53PQ5Wg(HjrPa!OHG7H%*VM%#jftu2NdtYp?{+6)hp@E;a%Ahw52 zsxt@t9%3tUqaPl*sYX^(+)K?rt!&P|EvXRQ2=T^}BbX9_&|~xA)9S9NqMoRKE-h z^=14WCDMqhx+&eYJhdvzt;~Kr?xQ29LM7Ocsc!#<$OftekI!dZQMQ);>$d{j`1J9H8t$&yQX^0UwVXZ|YMyb;I$L*@+{gfM`oNZF(d%ZUm_Oz5BD)mX< zJypwn50$4%r};rz_yrR!z_QI5Nj}@bc>D&Jfkf#Gn(y=9t@oz-a+Ld_oA7W*5tchP zcCGGi=nhQ;o?h2=7DQ>dj)62l#dxG!FI5WoU9JlmJbhYx?Rvdz|G2jMh*s<9`Uw0> z+{(2O2Gf)Byxii>s07)GB3)mb9USm}Ioi$@U%tY`z1>gTT=(4D`|>?KbkB{|0ZUT0 zqfF)93sO~FGh{XriZ!EF6S=(5=+Fm04uRFLESqBY^=w9$P#CLlOW_}%OT=A=2*1o+ zxUC21Rtk4E6PnxXNt{D!`S<8Dbn4Bk_1f_g*s9wv`BS?Zfv9-J`1trS1Z<9z6*^rW zhiQYN?onXJH#12c1Igu1sCleB2zE z2n>UK7&_i+7%?wIU;{d;E3$wQk-^?3pzp*T=;;K7LVR~DSDKx1m%`tmZB#WOZQ>3< z9w0Uw!jz7o4>^Fn4kZXnYp3p%5o3>~1OsS%F`*brQ#`00N5Z}MxEDpbi?>dBw0`sX zw3U<~s!=)cCM^rJuyY@T<}*@`|vEIJY|*9~ z-qhNr^b6a!um2YS5JB(0BdgmV{CdU-6m!mT8s=6+iU87!8s;$**t;1()K^!uUO3y* z(%zPqpWL(Hjy_5OvGQ3ZGMUNZN30nJdW*1r41RCm`1$ILoWzpCT!+&zU>c5LS3fGH zVLcuz$Pi(Cbl_B=t9$;mA$hqu1FQ-u1G%yadM3BngHT)=cN%!ZS~ zrfn8T8P=N?V7ls>*BN6ZW-^3RNiW4_7?o(k{2T)dM9xk2Xy`&zGLmhnH#h~jPvvpO z4a=V}EO*?R9lytZkpQ3;IdjxvsTSD>7wJRDIk|aHeD$kY+1XgFjPsqb$sq1S16_g1 z)22N7$Rj9_8QTX-kv)mQv&o;XU@&+%dh>@B&0qY|Bls{`w`R?|@4OS^cIV~hu3Wit z!Mypn_5I3MzVhF{`OQE7;ui^ti8wPF=T_f;-+goE&P_^8`in1p30HpY+iOtZbGtox zc|ohW^X7GRcVpV-?mfG+a^+*ejbk>@24hRl$XK{=!Ok5Y{^8}9v1%K;Mvfgb4(C&^ ze&da|*REZD*IhJH%<7#mY0}rf{`KEJi!-aA@x*xCvF@>>$1GcZ7bZW7+^!hcoH=tI zdg!6I*R1)$4}Rc|ar=B71%-vsU|Hq9efzK|d+3lM*eUWyKl~AfhCj-;-g3*MU;Z-m z-gxh$UwRb2-dObprdXn|m|u{;^8N>A&YX$j{ul1O7ZV=WuV24=&mI&#v0fc#{=hal zDGBpfaX$E~uf6*7pZ^?d(s6?_Z~pu{?z)q<&J@E7UKUI5gZ@SlwUFPiUX#7Am56+? zQDh{MS;46~Kvl!BiGoF#XkuNa%mc6Sf-ZV!FjfFk>ZE3$6RWQ|GjryV1S=zDj6}{E zko1xZYDyTf37M!klPtiNdI`@S4PA&zMv7DbU`F=-eZ2>2V`?NOQ;SxED`U%GEYt4& zfaie3k;9)lFhWW&X&8eWXasEPa1byl4daIkHZgQ3jn@#PSBm`l7(`#9mBf!OLJ~%U zlnkj!c=iQ>39aB6Bl74O*u+o}c{nOVBq<4R-J{5_q2;b{@h}u~hzXL(i{^K?wm0wH zE>h<#}26slfOQ^&--Cpm^*-0~|;Ji8x zGHd}8B%@-ZoY*6##eb>RQE>)$Ay?-nl9d06g+t~R6@@*_&)Q(`g$ujSHAO`?;Yh`? 
+
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# ASR_LLM
+
+The following table lists the folders for different tasks.
+
+|                                    | Speech Encoder | LLM   | Comment                                                                                                 |
+|------------------------------------|----------------|-------|---------------------------------------------------------------------------------------------------------|
+| [whisper_llm_zh](./whisper_llm_zh) | Whisper        | Qwen2 | [Using multiple Chinese datasets](https://github.com/k2-fsa/icefall/tree/master/egs/multi_zh-hans/ASR) |
diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md
new file mode 100644
index 000000000..dc2479054
--- /dev/null
+++ b/egs/speech_llm/ASR_LLM/RESULTS.md
@@ -0,0 +1,62 @@
+## Results
+
+### whisper_llm_zh finetuning results
+
+| Training Dataset | Speech Encoder                       | LLM                       | Projector             | Comment                                                                                                                | CER                 |
+|------------------|--------------------------------------|---------------------------|-----------------------|------------------------------------------------------------------------------------------------------------------------|---------------------|
+| Aishell1         | whisper-large-v2-aishell1-ft, freeze | Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample | [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 Test 3.62% |
+
+
+Command for training is:
+```bash
+pip install -r whisper_llm_zh/requirements.txt
+
+pip install huggingface_hub['cli']
+mkdir -p models/whisper models/qwen
+
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+
+torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+  --max-duration 200 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+  --manifest-dir data/fbank \
+  --deepspeed \
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn True \
+  --use-lora True --unfreeze-llm True
+```
+
+Command for decoding using fine-tuned models:
+```bash
+mkdir -p models/whisper models/qwen models/checkpoint
+huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
+
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+
+mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
+ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
+
+python3 ./whisper_llm_zh/decode.py \
+  --max-duration 80 \
+  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
+  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --llm-path-or-name models/qwen \
+  --epoch 999 --avg 1 \
+  --manifest-dir data/fbank \
+  --use-flash-attn True \
+  --use-lora True --dataset aishell
+```
diff --git a/egs/speech_llm/ASR_LLM/assets/framework.png b/egs/speech_llm/ASR_LLM/assets/framework.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc48bda781c6d01fb138737051030ee6c77bd59f
GIT binary patch
literal 853635
z1Lm>=LOn)Da;2PSaySs+?|0sTO(XLAXbvJS)JwYR>}`ep0wX!9rE>k-)U$$Le681H z0uAPrt?YeloO~?P*_@PedbyzWneiR+jL!h<oy$Mm`hSlsiI z^}-O(0;0*?H*5y(ViBr~uvJ;w`-=kEIp;LY1M6m{78^5n+TI!wN0Zj<@M#2fN@`_tCeWuH4*WiGqh!EEwc^(scJ+n%WHR{Z{8bIm^~?dOZ8*W_ zkq{LpbrvUP2F+RRb$285eSUBY^fiWV%VZoAG&mwCY`=0c>P9IPZ0fSpi#Y8PHB2$9 zTplUer^_8fnG~6|@r6@t>}-%e7vRY_t+0QM073!Cy~oGK#teFVjAPLova-V)O86}PVLQvXBlg01>bYgM7>?|Pp{S~z6c>b%I@gj9L#@G!kJTT4kJ{fM+ zN15i)L3=MeJhGCKz|?7stCa!o6R5x6MN+lA+-?OF6#PVONqUa9`}vvAXmCE)GTohz zIs0xtzC^oX3=_(4X*oKiP($GB>^J!Y3xcapR9V07?}Ifd4D>}L{Vk7u<~a{Wb7nEVA^_FH;GeTg2eO{jo~cSwDL zD#*JhC;JyzD0o_eg?VLV(3x+Z{FKnq!>c38_8ZMwjdtiLQLysa5X|R|%<;^WjD$^e zzR&+;7Lq%>A!(N#g0zv4s4Zgx+f}4Ch<9)>2y}Mpe~=YAY%Nb`8?!&hbp+z^x_9cE zOy>GNb+&i@I17~2vt#qX`hkGx{Sf>YgpKaBx~>dQYCD^2Ioe*)P_s#di+3+fXa@rG zQnDh@1{GZ`W^_A$RuNgFlfOM&JU+#uTK`==VPPAZ>OMugN^P}I<*?T_)bi@E&n62r z4CIiXtCgpe`vuAm$D~JUy0QIUyH*s;j^K}pXU+Q#S zZ*4VIwI=;Dm9bmH-!bS4qGAvrkdwA`KX<026%=OZ?=4uOVK-lI;ZO43bFXuJx3j(P zXExs%D8%wzPfHvDpHS5iNZjR;*W0UzN2Bda1%nY!xLoBPm#b5wu8cis?o*)OytVg> z_Qhcb!eavK&(hIF5e2(<3W7^0a4Xw784nNN%wA#1#k-l@zwdN3L?g7S+&USTZo^XM zIE~_a?n*++K|4(sJus!Y4R6~GLl0C3Dr@Q`qIabm1a9TeU~{@pq@C7v&wK8Iv6cZ^ zEQaBfKCJfK&80el$L zHXQY=?zP)Th(E~xMh+$$o#_qJ$Qcsna;R;Y8$NIaT0zze5ik~`$<{O9Bme5Mt?R+a z7ilN7U8+P^Afa29AjVqaUnI%*z3NcN>-;^VMj+56z~r#7&#wEa=qod%q&Q z6M7kCKV34#H8FNt&9kon@0T~WeO`|c^*cWs{np;EyL)VoY|;`N5)mv0rgHf$2Rq$Y z7fu_UPnHuLX@`laV+!EgG50BiWjO8 z`44qX{cs#|h4*MwN)#lS5zcKd(kpRg7r3Lbu#PQAFe>$3w*i*RZXgis{ZewWxyF0< z`ddk@7S!U7ZC@-JQCn}Z0J{aP{e(;>nUWd{2T59w24-nAEG@v8g>~4v^H^EyfQ&q~ zeX3-8f)uQM&gjY@d%5NoSNABBDEer%JuFjwEQLDB|NW|0lGbl^xf;JpuoC4Dp3%UM zZ$IL`qNGH~t8mEGDUo;{lz@owXNPS&e)lg$i2~14tXew=qTOq^3ZrS`kpw%bIexNnW|S0>vuqot=ZU|A@zA)J0WbkqdywO=+2Q5 zYVkoj4mbW9lpx%0(A{{fS^<15S0nAbnhlOjr|S#WnzlV{v^wfGnLKtza5u<1!o#=O zt#_ZLFTLY0wte;a+b&n$PGj!$+#FT~cI`ivcClY_IvKv~)CYXT9MML`^K|>C^K&O9 zY;7D^Gj2A+l|bnNFTttz`Ac_lU-uGYF%!A&qIQ1pI~~6_mG*od&xJxyjEZqMsLVBa z-y;bEAKnQuTCH1n)qhy_e(mk&|HfnS0(|v+5WLbKC-{34m)#x_Wp-U}votvzN4TFW zk5;HRJ;k6`0p5-G1oXSi%$I7EdvpeKy0hy(76yi6Z(pSD`?Mt?N?uD;11#6;oj(%r z8?1|MnP~0wE50+JG8nf0h?QoFhhf+LY*c3nv;&&=$}$fBzO>s37&rx?SN{@U67D^4YO|`$$d(Q&1NL4WmnsX%)o#g400@9A}u$k+EQ) zah{oYEt4fcYIxs^S%*h24ms04jxccSn<{U{9k4bn;&C`T&TmJrdR{H2)_twd0+}R{j+oHcMh*Rs z0z#p6DgL34YN2>LCvBqY3vD@_Emou(25^4{+>KK*iAf{7Dl}|yPsCc8)cpaOosksd zhnCB=#v<$vl;5bJO+5%`QZ8Ve_k~k~_(2}D)n~rxBPkD5OeNRRCe^R&DnD((Xs zn#MPhB0~R3GPApn*RCq(Gf7wH7;e_Ea|l{X(=wTzAPW281f^6qQpDtB;t8UOGxi2w zU0W8}E3o2`#CxsK&t*Q_JhZ)K>{61)c=eSQq_X8;W+u54a|cp$q{Gj^bi1l$PW317 zaLrXbgf)5=Aq0DXJ$O*?!yZpx_$(!n|LIhc&=`Q9Mj`mJ3tlN}@(4cB zUnCI^cTb+mt#_0l?=-7@^%)wP6#U<##u|jH*sH36XQ>}>HGYyi;`vgEDL++KmA~=r zSpVtL<9*?I*P>e30kKKAplhAh0q~T_X@F>^;X(13<=JIt?jv2m+|FUY$@aQ!X)z6S z8&Ee*eBE+%t+(p5u=7>zu7Al|dwaMJm_vp(I@u|bBn&w6=Inmj@x=IQZx(HP68{D< z1YWLDV|vn~v^F(9Ezw`-;IVw@o2S*dO9@QMK+gY)s`?D-?|~v%&o`}5hgLO?v4t=? 
zMD)m)3cCO5et2U9R9`K)I@*~m6zQf@=K1ecz3#MZG|5ZkuAAu>`NWGxTs$A_=Ue#Q z{N48Tw8!P%S2xJ#~1%l(~c z6-74t&gJllMG&7Lz`gI?)vV6hsQ|Kfi_?PrEqloUrZo zuqy~UZjPyrqFqtZSuLljfbP~;lSQ$e540^H_H zXY8?oFJR}rxSY0m+1aT}-?}`8kdn^v6K;lZ5%M{_#07V#SV40q!Bz?3Tu9X$F3oZ# z5hNPEtlVEzz29I<#0ouUa=2go`Y1Te^<%>p`gv4%;7YpAEX?uweHMcw30v>d&1Eqd ztv8FE2S?&Te4)|}4M~N;M}zg3IDMlTPB>Q<)-Zme+z2?fiSF`GSH4uLTH~wj{LiO0 z=XCvuc=k4jwarF#+ly{BLC)DXFC_&99SFQDkQ!$TL>Oe)MEYX!DSrPY*6rfzT44e9 z?aXks9|+>k`sunss7(=GGzbj2JwQ%%A)#l>Xy-H@fWXT_vl$vgGH4KvEGzk09cKn@(##Vu5tp#r`Uuz2>H z)w#)do&74Z;DZn}@PWVr)a`1b>S0@NKiAzsNTZQ3d*^GI9cXfP|17Q)+1JH5cA5Bd zn^ZDCm5Arq$kL@0QOx;4gH^lu5)v7xmOfO6xHFaib+zFLNF?;KoD$MUNqp&3rMRtu zy(e!dfY@m{Y@-{-$U^@etdvk)uLdni3{DJwI{s~|=go2X>;xpYQME}nz3tJZgYvf2 z=Pl9*@Z0~Rh5@P5qvKGNW3zO__(byUQ-Oq>#2*?hvjDuzT4P{Pa zyWF^ScUD%XGEgP@^VT`iFz|lsV)o?evct2A-vhMrJ`N6j^1Tdn1iUWCZSy~G*LaVw zZyUV31)jE?)_q>*76tI49gT&@2EOecD84_hL9-Gn8x_9XU{VXM(0c-##G6(z8dlAH zpK?0jF^tbI`JYe!)&zJgS5PL5O-^^5^h+A#3!r`}-P$PuKIg!p^L*yV6fQSFNh@0} z=~MVSobQB3w8#N9({^@iM-g_n_7&Q~!hU)JUl6~e!YEMm=rHkBFF-0Mb^S0KGi@jQ z^GZ#mYm(mCbmU(=>Nd;er^Jj2VV{wI4ym`Obz7T|-{%j}(ToN(C;6ZHPI0R4(jv?` z$Hr@Sgyw}tTyJBlby{Myhf5EB#Q?AtxZ0G#IHJT1fjNIONOS@Yl5BzTwuiVfkh_li zGhrM>X-rip{n=8$ks$Hv@-d^=db&vI)`u}Rm1HNsmtDi6Qu%eae^I8&YG}tmwW{BY zHWfS(12*2b^nXl!j6*QI&Q6Y&W~Z%}8_64E2Fnr}OZPdvWK^9M-H0{nZC8KZtLV1a zKwjTE|9YGLc5RqO)PTSVD#D=h#u|Rgc?%2wWGtbxO@*dHgqX2RTfFTZ8=rM-7`=Mw zQ*%Vf;ie+wEb{SeAWkdC=M=u}v1aANL1Eo6)fjd}SPVZuSYhHzF2XWI;T|S>oEx)% zWwS>y5uHKDPhqS|U5s_!(;#GoM4V0nTkp33F=bSxgv@rUQGbihcMrt7@d)173zoLmdEoW5P5g1BDpEt)Ttv+#B2^R}r1?`0JuK=0uoaRJnMbd=>AG@8L4%bHm5 zQ^@{DP*KQ6beav=r4thIKU*JR8X6&D!+;(sZt!;T{f#y%vhn>QCk*34e zT;3o5=IWIR82vO1Ue4h8e?u|qcjMVsHQ_7{reTDq$zQnKALTe2}7maHv~$?B#d8Rg_iaHug#lkvNP1Oz+No9j%lWHNqxt zwAs%(W^nqgIw}a7vsnn6(EMzj=DsQ)Y{2sSyI>@4(QO2+JYmoL z@@*;YHnW3jjS^ev0vWfwlERvNcZ8bwEd+~2DVn8^ch!jd#Ceo3F_MB3J*St>D^QpA zYt{t63r^8P`1B#sC{0F%M45EL9@|oJkB~2{AmZo3B3biPb7OEF=A4{ zp?)LRT!&@Ao%rgu7mB)QM+}s%}WS)y^B-O-{7!Lxqf@|-Ee>2 zdViLji{0vMka*9LRlJ_>!KVQ9ImUJAw9wAvW+=D)qiA*C(9o36nlsN&7A{CtYr zes}x0mUTSjC|~JN9j$(07awOK1X%}uSnUTy&&4$9P8?I{S*z>tBW)vWt4-Yb-N z?U09qyK}wgX|`vz`$vpA!~gMnx84pRrCGQj%J>x zuAEUE%2~h3jI@dl`vt3tjGviA^GyV5f*V|`7<;3if^ilQ8jlUBKU<)Q$?9h|A)9MW zRhKof43R34)y|~BkL#iqfQ2N}pAq>`S<#0K64{c;OchNh_k$ zMfHOgy;Ex^5G3Eg^CFr^W)p4%P9)!&Pq5aL;f0TZzNKa=VYUM{qbhZV-}i& zw@AWG1DA&?0$ygIp>t37>u>Tr91EUwVx=aBJf;myC9H11gjt<{KaaNnJ#A`#$!TbLgW)daBbqX z8=Rd7vcXC7!pGtj(U-Tz|Hsx_M@1EOZR510bSnr9-7O_Z58Yh@14uU#f^^rALpMm5 zf|PU(r8G*HlysNh@p<3p`~LX-)>(_iVwg2^&biON_rCUZUE72V2WP&{^y6h1*rh9E z++c6R*y&-`mR4e4Lq(@TvqqE#8|%0sV{_-^DEn{x<vU%1W` z|3)`>t+r@q(7fz!S!sSA1$u) z9Z0S)iS2at+9u#)vxw1vJ-(*)oQg~$_9PNL_NeVI9mlx*qh=~YRm-m?ua=2Ft68US z=hgmnJ?#7nzQ1((QGL1--@sNgW>iD<1)J)3i=*Dz3yFly(c4uM&wT2mrt)1eW&upqYbfFQethhpuCG=nw74~Q>ci;|;BI(%|F9p%Mw0RvT zcvXDX%$Afl=_WG!M1=Cdg{Tm=)FFgmg{F#YslB^#KrVV=NSf_n@;#IjU0Ox1-|vrm z?kjM(q7GfdHPyriF4-eVpA%%%X9IfH^A%R{IT)n8sA!wwELG6#JaJ_Za_oqHC@4o=66(F%Dk}fIG_(I_vc1-o+0);OO zX$HY5BW;|!PRG~bitRa$aLvZ&yGAh0tr~yZG?#(=c#q22K0gw6H zV#moDwdB)h=z~c4bFmNa^nN&tgl7vm^=vY%`uOiWxO{v!ZJ+q{*8;6$O2BY%s7;6D zL4@&6|bTF?2lW;Qx(`ldoiA-c&g z+mc1SwiQr+`Q0chTpu)J0J4@icYGO}iiaHCxQsJE$B=0qxN8z0Sk#u(A3AJ` zB*@ZFU`B_k%ZUnzFo%S$quL4z3hv+BgK*8wp5H2wJrY&jFlO(c3D-pYNTWGR9ikVh zfhkHEZGHCFw9;4YuNKb=`WzR_thf=)?vpG<#jL1B=MJbF>Hw3|Dgv;KD;}O72$`Ea zXzZ3dxRp{)1Iz^MoPFs$;`ns4jQ-buepPtbR<8wK=f1Ds9QR*1wHZkU6nP)-_R8U3 z{no?DBQlfo9?GlOQyHeyM!Jdl2>3J-C=0|2qFR(~UW}|rm98nSFlQmN;CV;?nf>Te zZD|vLXF}mHXTQKg-gLh+dG9?NGdupCJOx?#*-W!cRlu{u_{ZRnMxPxKU4N&&Rwp{I 
zkKZ2DWhRBUQGRJUn5~(A)!Z<-a_>Dc;&Adfk26rElHDkGl>pWGXH+|B9q_W^O$2C#yOkvgYzO#$={hph`W0`HD^8#XLQzTzK?`$mMlKR=7Rc>>~ zHa5Jp!&^jA(wq(DcukJd$m;X7Z@8+m*lvhs=mp3bFI8-KDL$YbO%& zQWmT8%)M~NU!{JCR@mV^W11^V3MoHV2Dumft0AnRG`U@1@jH(ypMIMPv`ydzc%WTv zo&aUCHUn8DzMuH+U8&HV9n-_YXS44qkI)SL0;tntgbSNLy8frf9$mjCg$*qD-Mw4FHN!^Y07==}7;&P&7L!)}{HR;J$a z%PLufl)0Sr$^7X$Y)EC|_2jeAfL3gi)Wf;XA=dKTtL+ZoM3uaU{2S;9bDqxzAkm@} z_1p#gwq!ILO%kOkY0OP)-O$he1B4;NkeBX#oIl=>fek>+A%$@u%tT590yYz$+vJzj z%6~+^9L$Wbee}Vl$29u0*f?;cD(Q2bk~wV)iKLvTI609WV|SeQs{3L4BA)w2;?qIq zao24_7;QorH@MMfJ^AuL2vDBS#kk-wv>!Gy?+suybXseMMjxZRd0f6Gd%QSbL%kDp znS-nB{lbE>Pz#TIJS`O_=QcEQq`L1afb_iuULG42+3kMJ7UD!|1gR3~b5tF(D-M8* zBh5YbZd+SL-Y-kuUjH`zfS_es^8B?NPfhyp=<-osgeY;~lsLq|Ejgm*RlUiJ@TdH5 z`gTqoxii=N?5ON-Ni)Z38;dkmy$uLMHg!mAi5bw~1DNH`xDLH4cpi-fGHgl~==1a< zIcKCvNh~5@Hq6{^*+I-8)mN_K+l0nNyGGkc|MdbeU~9{Nrk0?qEoXJrh=DSSV8d<_ zSbtF7R%70L7Ln_$(!LwSBS@5n{=x8Pm1%o%qJ$8HY>d6Af)_lq*UQ8KXDZQ*AWHu( zWL+=t84)VnhYiQwR^RrlqONTpk<%K*g)*Y za%hjyCeSAW+Fh5w^FNw3*{EJ0Eic{8}%60iK8%iFx|?{;pazyJfc0@UeVAvg>^Rel?)|WJ@w`%thxyMx zTh{C21(4j1?X=L^gN90Q+?AGLq+x%%PnFk72u z+=Z-26R<{mQyUS!aH!AQHy>4;@cO13HJ!^exb8U%mLM0)yQ$BbeK{r{uGvKG5oKTZ zHd0qD6cfUY9IYQ|JF$yTwnlZ>@L=q0RSGRyhAFn)rR>9VQPtH%6~i?*pwzkdCanJC z(%%*>3bIOBrk3l&KYPs4kN2lbO!g<8z~KoC1R;M2Kqb|^NL1XmuPR`m&l^xC6%K}R zq=ndLRz~Qm!%`x)L1prR7e5p}N2{i${j87!)zytz=(^Xj4m;Ewe%U8Biz7xUi7kPW zCfYN}$yrz`zI!KXDT6s;^UHI{(T3b%ICk_U?+tGO3es@iKZ3&_VKaXxkB0%SWKxZ& zj@zAVURHNNvD9EId-cbE;^q|f86ILJJ^-)zIifN(P$W~Tkpr2--)}MYS)Vvf4?;%C z!6)w8)u{bM#}I@k2d-YD3FyTV*Cvu1@jc%GChahZYfNRe@Uh8=G{#}!?)=MKwsVmK z3=7+7d`D3iz1aKxWU&m541KFZVR~E2_9fBE+`KfDJNI}i*Zam#H{V|-j{6)G|Lx=> z_P~FCxUgb*Z3BpEt@;e)H}d!ntm8#&XUIJOH-ct`tHXt|0cTYT;?RET;||gD&OO?< zKx!soBuvZid_kvxJ5jBl0h`YAjlb%7^#B=qxm0Rvt1gm^2Oh4-kN@bI)Y!y-U~S8W z>uVws zxk2%%H06kR(&P8COvz2d?}`472sf|Bty~dtPvXusX{{R`UpDg3(7PGDrlzPygi^g> z7=O^nP_FYCCVa~Bk#6gK)d>xaFag35(SHITd;rHwTZ&~0 z)fgW)578_vQf8auQ}W586b-Y;s%7$_hYtg6q8H(cj+ybLizZQMp_wTxpNSYYf0KYq z(G*}3Ae+3m!#6AdGi~d%Kaq=skykL8hJCT?<6mQD)w>IT?)`8gqXVM#+btwk7w08m zl<-w61QH)@MI#<800H`hLfJi@>Pu~{%OpmTBE*YY>+)mF-gahB(Bc@V&;hZ9-FT&L zwy@*dabD(@z`hV*BYeDN&U3wUYsIsg#yCZB$oBSMy5AaS(UXp5T)2MYXr&=khkXJ=~-VK{H~ko%o_{q(A{I84&gS&&`9WKQ_iAVEVfmjznnY<39?Jt z>K=M+2VYn}q;Jq)Gh+kq!U;TRA*-6Aja$ri)oB0o&HrGohWzOSbw?p1pBFy8IGjNH(TcHPZ|jiU-)PrB7pY455I;xUL8$o#3yH$McMtm z!1(r#qv^T5RZv@M_J+afL4b20Fd$Z~DIFbdO62cOU3NF%f3Q8%*9F%-<>6`I?Q)ZG z7P3#Ce1GWGw7+4_N?uBR>2a`W9$gD>lC?@vN{FiwE=uUaVilSdSR|1W(tbK(0jngI z6uNkYUA#Wi@NH8mN1*#Yi$kB5eU~cee`U~w6^d@I0!4xm#HA5A^0C0Nl6Krg)+-Di zYg52N#7@6sd|p=LuLQ zZPZ@Ny@JHW%I+<4r?3Rs^O@~cIeOT_&rGaS@@u;rw5Sue?g2M@9%Q_1O^c}%Xo4$i zvg>bk-%_cf!rBi-<*6)?+9Z=ZEm?UPzR`)B-|i!>CR>2D;(Su+b>mkHW3ng{p)?13 zHY8i{KC9wU<5`KEA+dT$oguLf{8hAv;KtKWmKexrmhkF_6g}JXl3@%H5g!Xtba|Pe zYWW`3xC~yxp2gn6&oi!^ zhVe#;mXKPvvO|&xn)fT7uZ-@DtjsoBl?jJkFe*nKvpzFnP{kmJMCP!y6Jn(St-2)k z4~yAT42fKPHB_adBm&R(dSu3$-7mDRe);)1nX=~ed0zb-o+?nj_=cGlTLe0fB*Ge@ z8ItvFrbWzrSZu9~noD6pM2I9n6VdsvgZ ztu11&w@EAmNoLaR)H68Akq*vtq!pMFGhL$VID-~vM&rJZLWDo$V$?En0%HKEDSyAC z;*0-X^6=4l1R~k(tiCxUgSvY8`7|hYIur31z=D+KLV*1pxb8LuNyqODqBfSjGEzyf z#M~Q@rQUx`o={_P#68t0rL$hD_V!-e&ZyLe0h8@32oeWn?eckj?aaYOr6F>C|nZz~8DB}SHVrKhTyWIr`>X49}gFH^$8qG@Fr zQrncq&^hz%z_oen{Wc+4VuKkZVkrXiovl$k!qv{ZH8`&PXCMQpKCP-}j^?{QW;mA$ zp{l@>!^p+irA*z@$UoMhl(&qm5>-R8zhNFfYv<9L3>*nhBcP~Sv641LXb4G!&Gi_a z>Xq8n3h?1J;ZT8N$G==VPDjwt-U4Q(cr@Y$nZry>aw|TY@onnaHnN!GIm9PSsd zeOx;a{cJdi%m|_{2M8OxdH7(2?ioZ;`AJiVEuDyW?vB{iS!ECs)zr1(6P^eKBcMT@ zcYpM&P$rUCP5ktyL#^+hH2vM+WcVCa8Q8r|i$Ge0FT4aje zB9gjqUmX98U)xbJNrf zhFeA74Hzz_3|(tQ8{9BUA$24@b1}wLLx>{iyvgg~6!uT;3k(DG1{gs?5abm4e7}8G 
z@y4X)>b>Ri$<~ENo0GMioTa!p(I7(RO~Y5*wPq0w-I@)93_J#YkkX-~HkFc!(xTF(^uPPl~e-<%D9{%wA9@6Xj zr2*q8f!+Q_7uakHS$w@qe9dj_q2IsiYawYfoFZ**xApN|CYxL-8J-zKvBEhhSxHJ^ zs=fTI+L_{U;I|e(F`F5maiT1Atn{(4V(fts*U1z#SPz$GFG{=)=e{oV|D@J;pD@rc z*p(2li!beO8e)*D)lfd-|70W2o$ zk!}F1Y<6D8*3V1(S*r{?sLVX^+oP{eBK(!BF?RK)PA@*(9R+8u$ztjgSM!%ZehrJ1 zc)_yP&y3t9ys)weJ32+})ND0P}sJsjqe zsI}CzOChBve2X`F32jkx*r8%`$7mxrnrGJzIbDJRrRBz1VVF?Og_IMZ6EG2^17P{UNwZh%h(gnB@X z&Bb03nPYSn)$ej-nT`y4fsU#obfso0>ku9sb{Tj)Zrr@&)Uv!#K~z65-XBd2%s%>S z>E(U=Enq13{Zn^OX#FO?OFrl66ruk!sX#kmkUeWTb&4E*F0HiPun05-d?Kv-f<)0MiS4q*?I}?x* znM%p|PWrFPZ{=I-;!l1LdC{^OW}%0>P*C1AW`q6V*oKk!R4(ngrz~KXMnkMlPUu>& zWh6vdpom$VO8! zBX_pf2!+-O?T>!RI|H=cNdF%ddEck4H8%L1jq$x^0p|iJh-fN+4 z$+i}MetuafrTX|AP<$6!TA2@P0cNz`W8e}QbG{Mb1bJe01cS_xbbMpL{gsCF?tC3B zRD7RM-Il??@X3;95wJh9bbKi`rSvC?oU=;Xj@=qv?L!YS1{ck4$_8Po5tN+coool* zdmk|F0D(Ta7a;~7S!(o5Q=xeh)8TSusyGsl!G7uCvm~5wP1TXXMpBtmP+Uw-h!Sk3 zS`+o&K_-!OQ4*+N0N%#JGH@q6mc8cJRTaJrW09K*T%7dPgewZsy=_hE%;O4i%Xx>w zk&UYX3e!~D(?KxNP4lT}HH@M3r6So_!1GVa(x!Ow#z@P%qVdx^I+ltf@m-pJo8e2` zbf0UNKXpK{${GXD$Ax{epc$Sn`P=wq`&SX}&{+v@y8l3f{lb-W3Rk$_$4h+A?G(z3 zRM*>_+ceFoRHV|rr%tnAPkZ;BvD?Cql|OF3)^bgJukpD}1$#-X2HZ%A1Fo&trVBri zel1Gs?HP!t@QL4&+I-BrKSu!*@Bjs}1#|_SJC+yKoCAH7doK3KTqa*nM!CDfoZ_hj z0mKq(hTToByPsaS*m0qr&W_%%CzOF#+is=9*N4YEqt=wcs1!YYg<%PqKH(-4wMO2@5@N=m;1Wge4=%^3Vd0DRD+h0sXQ z!NMB<7riQB$mdMM9V3FgKZfKTleG~RWf?j~BVOZ}7>!xt(!34v;LhuzQMloPriXv1~qH%!;V;%6DBxT-#!eos_KOszvf^;{occ>{8JbE zhQ;7~EWu(xaagb3vRUZ;w@u6E+6g~YQm0(?fXRD>Afd%kC5w_gRj4A9)Wx703%8U- zr~xaVl(mA2&(_`M&Um(Fh-@0#d6R~(K(S*|_OA+=$U6~{m!e+M+;5}-53rvZc`ILS z{`~YEAZCk(0EP1Tx0{efy&AK4CMN!5o4u)0n|MA322_3)z?sr-3#P!7^ENt&Py|H} zKOxR1H#RUWd0sl%SHq5w0Nf8R75n%=xTymo!Rlb?yr9V+6~?lm^DggBEM(x@P@UOsdi__ThD{;|~?GNn8@nOmfwf7raH| zW=?ilAfe^KXRJYBOegCW|onMY)!a(wXh$O#v$vLO7a>TA`r)17t>8z zFtR4yh8g@@U)Fd~49yQ`*h34!!-xEwDL+dwNMnnJS_neN>VfaEeW(PdL>YGcE>%*v z=;^--jJH z8N0!RG@8!+*ZpW|)?uzg(rPf{RH?L5vzJVpwE_DKFxv`>1>ofC_^x-+!u%;gsxh;n z45ZzTXP09g0$CPQ^;kIH6@=gaVC#$|Fm+X3k>CLj7b=!I<2iU?}>h?JU@^|g@8Am}dRS^PpatkhYfHTO< zuB^t&PNUz=)ybz`6YbOO;?<)>wiZ$jo|*>z9?OfjONOfjo(_;l#+3Xf2y-uC7Aa(h zzJeC)TX%bHk_CXi0wT6=a^)yfeBt8*M_9S41{!bcYj6rqmg2l8M=qd4`sSEEzTcz> zRxanQccYigg!s<1hGgF8!P1qgdC3-_ehH z5fP40p8xv=eAY-fg%`PrF$&MDgDR7_J1?dHEl2_p$jtr8Dii|`<2g+nRwBXO1<>l# zX4X+@+@FIPKe;pJ3-l`+YUWtWbKLt06~Z2!sv{Xv8EWjX5dR`BTPCs>G0ABW3`H=yHf=CEC)F<}LkB${0#loM3&oD(t638S?Fc z#VC~LaDXj7z(DJ^e$M9Mk{tkK%gZs&0A}t`)97M6RKHFFTI&64-s4D*_ z>{ho;`O2dnaBMnf;>$NL&S^ED^_b*|Qc}WC42u1{s*@*JRzjF643gW}x(Ox7xs6G< zbmNn8QUSBv$3A@VpeazzvAmu8)5EP!RzKdy{5Jsv%zm26HGdL1I5OCqONA>BidRb& zeDZ2bl5@D~1!-u?R+VMp`6VI#d)*D%2Pdub+^GRKoJOC!Wz_xwdzDEi;lgVozw;&E zomhMc6%yksPV0j;vb$SaNQb}*!AJDGmoWrL^=Y|to~iQjk{wZt8mL18BgF|4&44Es zzQoDe`QJ&De5XSRpyq>><$eE+f3+g}qj#XSVj!=+6k0(Id zQ{oItaPLGXd(*j`DxR$dwQ*&}CMBM6G7+#($vas-f32O7D{lLRyNZa|r z1F_`AMv+EafM}@=FrRQt=+PZc2UMvde?G4O8R4I~_#Me2j@`@$#h&y7-C%BzPH`Bw zN-fBGlS%bwvs1g%{ncu&=M5g6&u^1}HV%24Sz|%aezsD# z7{k|1HAV+jhYnI=uYujInn}#c-+defDf@R4yyFs7KjfmLgd6{Ge!mfemeGbUl`@Iy zz!grWm`*x1Juu8dfzgg8iRE`%)J0OsZbAML;rE)p>jxz$Z=+-p_S(O}E&c!k!%|3f z2aV;RpvD^|u=vral!%AyS{o&3WVzPqBv<=TtnoFRHKmbsTZE|MW%iRh{~jQ7p8w=_ z9xq%(55{TvdzR6K3p75ajyBv`6pl8Sr6>Z2rAUXRs9c2Hn71vpO_@}d$1D0z$8+sI zC0#MCZpyV}FtPdFBvIj)gBlddzMQ!qwP849-?U%{Eza2wcb55*;C)oKpo<~pnX{n( zNumTy^ zl2U(++}4yRDh&l~pqoYGfJvkY3i`mY~r-((z>h4EP1l^Xp#&r5#oTXxxUGkweO#v8XYGSq(&bTtASMkpIjjchTp(T$j8iA77>YuYsi_yf0b zOIyvb)p9#$F@HyJ?07D~MF%a+OQl4E&h?}^b&L1Y{XnQL2od7HSTTei`nJfTar97E ztwt664nyXFhyjf0@+VE{xfYvt)Eq*FS3@VpJe)1t8b%l~3ft#{2%u?Cp0Z1=K0W^R 
zJ|Cx^uUl<$GX5g<>jlc9^zp`Z_NO)EbSivX!v`j6&fb&uO*4;rO#yKzS&;#wAU0DR z=hB;_uX;{D#V@*1xmDshm%crmjjD_N`3&icfJkyh{6T~~#jT`y5DwA+#dQ5NwA)V@ z+Q$aVDda?d~%(KGm&mxTY`&pm~DG1gG~xhl7%?QlBYdminM#|t`IMbP-H=!>kIRkt-NOD zVVU4H}3QsejZf5!oC*r&yf`}0-ddr6COy2Ap4CkJMv zp9G@IP3C|>papK5dQL`GU6xH`;LTEvZ7jC!C=ohhjZzEhA9Ag|$)aZNa4=>10l}2C z^ULQ66}sv)tkg6?*m3HlbL__A*`$d|QrzB)HY$|p+aU`No9@0Uqm&b%xo(wtE{bzF z(+|N{%rw(D7iD~+O{CE=(7KgQF#uBOkc;&{?l=d6b8m{%>01flPZ}`UK>*qF`-x`# z4*=gDk(!!EKj1V-HRh+_iwoytDypA3PH^UP_>Nu6Q@Qu7%4{;8K5=QX4s&^&We6W0 zw#s87uElLl3ClMRxYbbIC_FKF=&nzDh6?Y}A2i8xzhLn(z_L<@)7=M&+ZHVceF5j@ z6RB|BirY$ShjJw(Qk-N6Hj*m{Jy|wZa(;qLS-h6yLmf#-0S|va9#&F&*8B6zIcqVl z9z;ibyMOgSr18TVJXjDhFfmmfv#qnV`mYvE6hFh1d`YA`yx=*!GoDB*SozjDuN>Pj zx)yuGFY9nowzxZz_fO2GaE4474!mOxR#Oq2*G&+!goStQPJYS*&n`nEwrRezCO%4& z;qj-GH(i1Ts}^DMi4XFVhNYUTNaF@7=yA`98^Dz81a(eXi0!qKpfu<%j4&k<`X1tN z1XWW>NcI0*pRg@Ix6fmfR2k9clj1ZKJb6jO-v<~jwbGBGR%fd)K4taGtg~VmeGOPa zC;#j`d$$7?4tQfbQ36h>7yk?RxF0UHHsM;3r^<-9lZ)>{h5-s3D=PbZ&GNT^>^OLx z%zjfI&;HM~XUaxZt!VJEIi!-2sV`N+PRl=EF`!;!RADQ}OZO5X?7EIq-sgUJ4~D?M zHP;#D5=x>O3)TMo=EYF16XB#9j)3=PH=uj4~gDSGV(y2pIYz zq7xId(e?Dq=*yH4gH&rx5R;I9cm|LXSX<8we#pT-F`(fJG^xk@yvl7tmaSg$h6K@P z?&}IL*354BqcjzJ2Ya^H3oTzatdS(Rngw*d63St_``PA(NRFd z=2i7auSNc)6$z0^=JuG=uV2-_QbUI}WWP_TL(;~gb?>0rfq8$7gK;)}u&a0F1PF0N^*MfbjvQE$GKE1}}p?>J@PbOuiSIl#T zFTML@*sBCQcEX0lw5EIkr?!zfHzPOf9}m;Vf7v^(F1+zK z%{AKSVEfp$cUk{VCU?M?k6%->rm+;#5x?)a`W<>tbm*z+bTN|$@dR-|V zYt||@XNCX5Y5tp|`BkLqW-p9%PlEHChVlSnwKzf4zGc~483*E>r~urM3p>Heag;~0dYTZ`&V|SYM?v{E!Ta> z`B3!-LVT4m7lCAELH3dEMzD2T*(VhV7!PVs!DsqM(2GP8P*d78HdE>6@QAT-3JRu* z)E^X;IC2)e^Sy$>$-J{>x~|r#Q0=;r?z8#4Wi!I!wKMLU{(H>7}xXtUMFfE zU$$_DTDVTrmg|rhi_mneZk=QFVkL+B-RR&UXSgLm=im8)F&A!;lNh9bbC}LCCicOULlX9I!udO1{7oLupv5(H6lutpN4{rY$H^td?%gd#hi}RNb}rhk2>9Y zccl_6i^Q__KVsPPBy*exD!L9uP!+pW<@?@-Pv!w0*phv2=>}0W#Rz7>@g+8CeB4QT z$K^J`D>+@Jw89E(*a8RjuQM8-lb)eOs^wSQEoR=$cj7T;tZFVu*1J@WB+?m9o&j2P za^y^^$^Ku(>wP^yXb4GP$6gZUuxD-qzXElh*K3yD%=#$c8MYG#Z z<%J`~P*P04^vd)D!-0k;lK_}9iOYPfPx>OsLYf;vEl9k>LthOS+y6`QHzGd!&F#B6 zLBN`S=b+Bk#^$JL^38&?|H)d9rlB{N36O7^4DrV?K`odONNmY)(|6AmmtEjU+Ni|7 zGQGHQsh!m-83gT@T;E1HnHB7V_F{5;c|tbq*$$>SKGA(Db01Xro}!+DoVDN<<#dvC z;;%RcdGV5zg8?M-u7p94ty2I*sc82O#hYS#CX9&*>=_@qG))>#JIfFl90UEk5X@Sv zeKi~s%j2gwyOFPh#bBE_rwNC^sYa5_D8FqSnJ`K=BhfpraHS*R&WP>HzUaUvSmNfRxAl1O#bvM9O)?4~A8me-6Jo(+(xPk01 zvXl&Bb;1?nstlWKeyhS=d!w1HZER2D6ro)z%t;8By)xjt>1Hvw#i|#Cp@NIGu}9_L zI*J8K6s=<+djc`?d?os2J$?}uc$0C#6jAYhX@lD#x5V<6qTvw|Z8xgySBj=%-NU;& zj>)OrG@~OhDhsU!U7~ub5c|qk#aq+RC`-ndgl{Tj06kB$A#I?~>~=;my&-_n03&_O z-Y*J2>FKXIzLC3h}yG!KrP#~#3CcrF1 zEkh-W(6sn8Oa`SkBEN>`&qzno*!3PZn|~es{Vl0qV-g*+1?+c$0C;m~|H~=Tz2x!& z^FP?l|Ib;$kbl$5QS`@MF?`y*d4*5V{-XwhCtz5JCU3wN!lBQTwUe>>+41IM$7(U zKROe;zPKZ!%onyrf31+wI&iPr55~~9Deb8%AP8P|5>sIzyV{N>ZngUG@mmDfz$B_B zO15kdKLtqhBnSqUt&bU>-QubKyhP2bh`%zu9W!2P%jPgvHDB9#IY@ond4DuprDttx zYwp^EZky-3!=($u-Bz5>B)<3u8UH(U1yS=qKHZLG^5csg)&uam{w#0UR~ZcH|2-;o zjRD!>RaK_AFW(>iX_D~a^GRAks>!Ia;?l)4?+oiq5npN!WG(JfC8m@p3v2f5Ls9xP z;35R&d@6wD*rdSt@=BoIY%f_cw)HRXvd8c#OcdGn&u+q_qnT;ED{pH2^ptuo4q_rr z6^W~~&J;3C1647+W4NBnvz4w0=u-(n+0dnkDUKM$g((j=-1Zie(A;}0SW0m+%wWT1 z0-BCQ@{MT9T};V>P$I%HJsRXN8X-zLMR7&<^`X&~Fr7U3j{zMn8XZu+|9*wB%QwpG zUu&%0?^k?(@BS!$%n{Re83$NoezsKV6w3dS_WvA6M@sS6U{CRf@hOSaLtZ z{pn%vq!e8%PYiI+I}gJ-%qzjBB|2OsY34R+WF#vCxDsZEx!Ygcw(yZ25_w|2Z)dQ4 zAN3jOdu~wS17H`!jU*)aZE&R?Mea24h`ouXqqU<}M2+AuR7ADoLYk|Z!a=`3X&C*n zG`Jo|bm-Fv(vcTdnRy~-s077!luBxH_6Gy7l~7^EH z5)iLCfhrQ1Ai5~0O%U-CC>;MC8fpAN%yMrYFQzk0y)zxON*r9j(hiP#Tp#dV>^|l^ znc>&$yDyKh8)&J25b^2yqAz5z6axaW4H{%A~9{QS9n+4OMpsaRefUx1z{_L=u;{&4T!> 
zC`+R;0$tj&{f}+lhmEpm;@=-11+RX7d%8P#2F%$p6}w*L5p)0j<{QNGf1ZVQ6O5gS z-20P0d~Sn!Bz(voaPp6Lm+xCi!jK~YTHwD!XB*c-YYXpm9^inQl&j^gGbp@ySg=BU z{<=Gdq(VO0*d&m#!@~wD+oLp0Q`6b6LV=IOcaBS zr-;eyY+4}^|E9v@un@xlG2cz*nNr64m9EJnqN++lE{=u~8ktt>#tRP%tIgD)Y|9$i zFcsGeu-~F#$CZR&un`nv<2rW?phAa_cc7RY!I=3UCyU<&>>i+JKOCv}t5# zb=$pbI)8p~foA|YBDv2N$&odWLZfb#9+!ex{F_vU_gONO7)1Yj`w0FC$huAetGxwD zuHOnTy>x(8E@SrJfFO-gsf3{snY+D^ZNk=3l}Zy9vLuiI1fg!O_E+H+CL-Jbc8J!e_NZAXtOo<^>LF%*rj_)R2-t@R zG&&%K{C5-}Y>_0HSKLQAR|D^zAP_p5CQ>Fvy;}wJlqb%Toma~yumY61 zPlz^!<>Q;Q{(q3@Jpfg~*Ds+2=?2SUk=Sm`|@mdbTd$uHb8?oG1OCFQC{s~Y#k*8PxUA@5i-NXIxa5@Cke#;G3&*Cist^uIU ztD9oL?)mRO9b!`!fMM^K+=db*p9}yW(Jr-|o1^<bj?jkrDBjwGLCLGOusF;8MW4BncA4 zwL3V&kuZa*HGVGQDrHqB=2~M9&xpwfH}df((gZGrlbgn6Ci^^pLqq$plkHO0ZLX^# zEz(%|qJqHozfpYyx8OzetW#VzpEY7*AhwkbTmZ%9#2Sw;1|n^Pe&TwT#7cobwFM$BRbNQuDjACiWPGOLGaIIrSyhG=6ahAP zbr$t;J>-cnPrT9Gx$%`J?V<;V3SM=g2)kq(&D>qIMFdwVF8~$(AlR2DY5&YlN zt;|KPo_PXTI!}PF6{J)sK5Wk6xbu1iB9mPi@J{A`a|1CujC#X=e4e`2oh(2CFx6CWZj?I)K$AJ;)RiOwfa^{PR z#8p+-b(H(g`VGNrqMHDl8UJR7jO^@dKysE3SfHwP*o%!Nr2dvk`frj` z>59Rmo)=dxf2QO_H3BdKTU^(B*~BXe{Qh_4bgTW_^&1?c`gC0HX|P;5V@$ZnXFMdy|d#bR0rhpAKhPi(_P_~$vITS|`c+oJlvDjnq<+1MLBaR&SXxT0d*aOQ&QX#MYcMtT# z0IdE*O)7ouL8Yj~pi_ccp11D$k|Mn~|8u?A3^6{D`Cc!-uQP07YL@cv4dWk57#|<) zdojWyT-8WMfL!j>o>o|-C_qSug#)D^K|OJv^|9T(5LG#av4>O*0Oj50LF3txAmXdF#s23f%i zL(oaR4z?!Xi~k-lMr_K zDR2!iFo28%*z>-)tbQvfkm{-QMTWBi@K_Qiz{o!gmS_pg4YQl7F zov@+1Q>jnQ&;r_q=A*M!#uWc;1}Vily0vdGKUmp4qBY?AHsu6mfI zFnx8wI;V{M`NH8@*ti9mO8tey)tz8WdZ?!%Zyq{%@&%s|T=w&s>yE%X*BRa+@4Y#v zz6d;@S5K>e+fuG$0+dC;FClOj`#)J4pd{g=>{fC*l#g*01El)}Z&4`$3p~lEX?01k z>i=E-(inj?xNEKS3tuC4fCj9~EJ#ORWXW|u_asm&n9Hi;4u!INUU;A&CuUx^iB; zVs&q{4nRm}IGfhpotreyOKnw*Y(MQ{En5^gXAu&{ zfSnBnMHPcJ#Sg1-D#0uv2fcE&L4EUyd%M#O;Ff-MUT%9S>Nz|ykua>03Hie2;Y!#1 z-z$?8%tBd^_VcBfw=Ac7cM$SI(sMnQw@-j9e-+?uI=AhWeSA942QGpATy#jiP!-+HNf@*TmnF7Is7-BckjFIU*_$3Jl+1L4ZK-|UJ7Qor%M?v!S2jka5EA{n?+O4m6`D53yDJOKryvgF=^1#u+(&(Pg{-8+&S z($Fsz?{uf_i&7PF@cPqe!xg4OLCsN+8_0;fTyUmdnS6W{QkfROWYKQdt6@kK@w>lr zN~t&SKUyl!^TRfu0jBg-FH~CcS)k{YLW; zG~%IBf95+pVIXJ;TwArgI5%K9vaN{7{9nJFFaj~6k}@7_@p(Kv!AmQ)PfSF4540zl z;PKr{R1y96B>z3Ww=Lh>!;T!8-V!=&)n{pW7P8I=-g&^r6KB3cNuImVdbL3-`-kd3 zGEM&A_fkXtumT0XtB-6{1_p2s+dbNV+xgGQ*F zy++1xHD+wvB4-!=XIxdBhtq{B`B07+pIP$74cNwIxGkus^yrEQ8zQg%(v`~&&C;nw zOgqdj5TXd8beVoE52&X}Rf-I`RV~vMmm(?Qn#Bt}>X|E8Gz_ZNC*DirVMj}5Ox8(t zWUYUR$Bb7d@k3SWN9}eu%b_zQdob|13HikJu~-tQ3&AOY(@pkDBv|Bv+Z zuYFS)=>kY)0sv_oa6xCcm>a_7R3QUgmOhWFfq*WNKf7D>fZ;M2p@VDK62Pl=-PKkmkx6Jv9M14gL% zgOFJ@1%XIu;IjzB|FHIxE^&Q0!fe`>f>NzlNasNA&hfPR(udLVr#EsWpDzs--wuMN z9GjF7vFjZD2lW4dPUK?{473liNAGVJ@0X+q`{F<7YrMhrl5XL6dkslfo|@-UC;@U2 z^`oVM^gthwRL*Q?D^^4i!6)5W(Unk+Xc8H3rA+}U`ujaTu}Z~#dAHcN@tEA#7`*@y zSfHT^2>-qR&kG<8d!~vI>hU&3ZMbj-gL|+`M~xIG?snoTovg&1Pcb!++BMlte&lD( zhOt$KV|-s>?0y$h08TAH3e7(u&rJyR1E^RG-Z)GkS7L;VLTQh4PB>cag~;Tvp4twhWo$S``9=Kh<2k|-s0BEDn+(X%a@2A zsLjqtLS@CMe!!YV1+>vP(399S#HzBNFv{ZQU*y=#OIUtDF0&G#2eD=u4s|ChN`_(d zFN6$=d<)7Ba&&Z5ssAKHP}y=^%4|50gA#uG04JJOWyQ~yvJv7)F>(cku3VA0CAr`6 ze0~C%gXl{cv?Dz(ah-$=r-n6qDXZrI8hs|0@k+dPhD&W8w z9QFfHRXp8aMAOUt6GGAG0d@+8t_{RlOd*sCs6+xPi@XF8q+Em5x7IDj&pGNP>Z zL(1+x{?ku^#66Fi$x?$fFr-i3X$eU(|ViuR_Ne7QVWgkzouv zN!D?P$36WO^$9YfT{nhYfnb>k+g(vsY&q(;@0s!UThSPw7a+mqRD8GG@nGWp^}^>J zNYNybJPi(5Lc=co@Xvl|)BS|r2v07B6uqEF>#$v~8%nkm8mX31akO_=0$+fxWWR+3 zp199N*a+K^HI}GC<104b{AUI9u24T$rcm@$VRd}Gj$X#QnkMiFEi*e9Xc@niU-o0^ z@?DRy>W8P6r)-ABUjl;@RB!hCi>cLMpIcW67{~)Qh((DDJ%+KkQtOJ!^t7`WW_#Nq zdWZP%M{T?M12!GX)XmZCCwheYx(}#OdJ}+maUfEH5-?y5PliXZFveYQzmhye(A}3sp@mplvWR*LJncGS@N^3I3ojA>7hdtf*P?`5Ol}%)Tik6SWL&+-o~U 
z7Fz{ed2Gzg&d@Bh3xlqSm3KkePd8KSzyh5>N)8G|q!FxJ7Bv12cW9rieS*2Hf<{-^=p&`R@|91^4p|OaD#w!%C)V5A>YL?Od}zI0(7f^l zt_dGnJez(Lm;oke8|%04>o^vpj;tOXy~JP4ve9l|mTbY{t*bUPl0@iF*sfE75_~1d z(0jb+O&@@59&WwJXE1~)4TWrqZ*@LAnb{|jD8%F6t_Y~5b@=cDP-Ibz`wrXCWO2WF zwh$THARN9!O6kSqgCZ~>c&T%tB=;nZKCr`OU>^o?BSf(oftLBULrA|fkkP(~cSG%R zofof)2g2c`n4v~XaJk6 zp};{Nz*_Hr|C|$%at;HUO@Q?tPlNfl#7kgxCv=O(xcHzrBT^5?&i*;^Ur3h)fiz!= zGv}{k<;0MF%=2Zl0ptGgq2RD0D zVd{Uyb7u}mp54y75ZOUMSNg>KRrZf4a+yN5Kyi>C1}pfy91IxZ(J{6Zuc8dcO~gS0h17%uQE0;ZmfyggVfS zBoqsZg!a(x(`Pxam=7_T5aiwbK1#wCIQGTTan03mwfVL^tv5i9=^tx4$8V44n;;U9 zJDx2_6b!S|A1@RuJ4aqEjOMHe3lU5F7w!lUX!qkQ%vmVji7Tg&!s{~O^E%Vj1U>+! zIclGAm`Oh4P3R4LCB!IkPVtjy60PX^X5VuRL~wPr{iOQ)#M$|Xd}lF8BD9Wi0AVuP zV7o~OnE}j$5XACmA_2qcLu!Bw%#QHK1{x*Y(F9H{wPJ~w9ERRlci-PGkzWZEitqi) z{x^98kxEfv5naAG*ZPVu{ayC{B=9K(p;m@j6bPA@QNnIDxGn!K;z$bQOGDu_19VN0NmIqA6|l3vXb2xi4?M5sM!gKv}#e%tOvT@dx@hA?I1XHH7>i$9ALMX zn!Omy)Nlx)a7#i0A0&aaAS`^eZAy##P_s8mOPnX+tsH79vcriS`8?bivR z_hBwjtb&>N-JJw2Pt`ozV<~=^IDRp*6-#xH*^wR!uzV$WZ}M7Z>{J`Btzo%0n=ci8 zzB^lK=%1YTsJGVabtX20_T7kabVgAOqd|^J(!qDSuH*Aq7&U64VEi0n zCk;ItLZR~CM-OtY?WJm?=ag7dgfdL|ma73^M1tA;pP6b+TC$XeB-p_ykxQ0j=)VS_ zfYo={g=&_}@;5C{sd-%+R(~Hb@e$qfE(V}5(;^#7xC%d_12abOcS##!N7l589Kv&eSLtDHc915%SZ=Q}(gc1$CtQlqKrJz^CjIy9WFk%wcM$J$&LA*%0 z8l=sSCqVqmjEYjI^10!AXfKV)=c4b|BFzY+!;0}D{m8d;a~M~dDHkI{Xv{egLCgGV zAzm|5l-j`Tsm3jjP1p2o7_(k-j+t2(b#-4T`iYvt+n^CgnL8E-#Pv-i1fieZ2u6vn z`VTg#_%cda7AnsnaRRkst*8n^kOSF%41;>jM}570Y63G^c`+#zYVx9D!tE~oNRGH@ z0xLzPOg3}$34{WuAQPn2jJ$#1>K)g=-}B}|`sY#%Lzcr+*h8zTRdIP!dcRx+;Ha_* zIsAP>Bpg!^6Gr3iZuWa2K8UGK!^{6xR}7uM7OKIar3XuhPyA97Ss52Xl-rX$95LKL zwPP@+_BU$Bzi_jdFOMp8C&Xx5Ek`KA5p^)x`Y)W4n0-fXK@PF7Z#R0s{Mmy`NA-y&7<~P>X89nM87;Kf^6B+68`Yur1=Bl zc$Bp!GXe`qTf%Z*e`xlyYNba1j>SS9`~qTCBDI|)T}ny_u4Tp_dT&_mCfDg1a+xs( zv+tgLs${mEx2oUHR>spfi3A*WuSmZEhO!T-e3$3AP)Asd8aLZ@=GmT4)Q@MYW%pCP zs5`a>G}tFl2L-dn+n$j0a z&xMp%mZ&1@uG;~lrqx4Bga&#CyT;k~`JPBf8iQif+^rxMGz9Bn!XEs7-Z^*Rbt_(X zG_spw{#@*jG$7+|@Zb%Pi~c9AvGflq3kLb0gOFt}^Wbfig&i@AKqRTspsH^%#gqp8 z58MYC0o1O<~8)G#5Xv^1qtzWBAIiWUiv662rtM4CRJ{}Q63>Grn~^LMABmrvd(l!b-HRYqc- z|NJ2rt1aL%jJBc+*Mh4DgMM0UI`WFjeZ#I7nW|@-78k+4iCJ`?Sr{E&VkOp5%pdDZ zDJh3CnXo?%4GL`(qPNAsV*omdMJyH7Y-}N<0nAsBQ6OEdrB$wsOT-XfUcalpe;P#_ zaq3hT7J9fLq`ZcHWIiemBe2tx68eQ6{NdIz0!{)=0+|xeA?g$|niaVSL$b<(aF&o+ z);!9A=JzDzme%<2FeF6u=PfgG0f2r>^1*l#z+E45x+XQ2BvzHe$t0ZKhvlae-evhM zclnv96Mws66Dj6ApM~hrM(nuU*62WwxkqnNHsWer=V-A~kO)FZwZjpuOMg{~XwG0zrZl^%3bNdQAO7zt&j_9bQrn zmDb&DTVIyBVeWkw%SZ9!W{n#wO)afz#f!m4GN(cs+nfK@1H4iiYf3^**=O@3nEl8O zPDcv)=%Vd+6EC*#g{XvY;uhq4i@k?ZSelTegLwTP8?1WxFsYlO;QXVpJ8{Z)6Ke43 zBGatx7wwm9y=QZkC||}Rbl0vv2)d|X3j6Gp)7S@d{kMPu4nD>YRP*|pc(Sp*Jvyg= zDV5p&EME2T7p0l+ufqCpj^|SeS2;6*E$9?$xmLkoj6b=#Uw^Nlsea6d2u3Fv(J6K{ zpk}+3US%5zK}e+g+{t8%a=vTD^_gF>r*}?7dDCp1>$x8TrWzILk7J%|uvApaAgm_0 zG3umFpNrPqB2zi)z7-s#F}tg8m?^ZC#eq{$_wAc!sTLfR1A@R!BL(5 zuwFbJ3=~dyyiBh`LJwDEB0{T$gO7%SB)ovrp-HB8uqxJKPb)=sk!qC}buC zHQXf3ZeLjOuQE|LUzSwinfh9+5>S{cRTu0aMJJe{(VL-RSicz8g^96$;tTqC!4+%7 zIDXK-rb5MB`DJV^TzGm`Mm<29d1Mefzf^#g818>=qcMp2uUBhoT2LrFyJ(I&WDiGo ztr1Zp^Z41M73mfV$q3XVOxil1Rhl{6#K^O=bbLO_(0~iN7!Dpfs@)7v70Dp1*6Pep z7T!%SWV_Bv!#uP7##l)(;A4CPnf2k|(eh;TV*DV&>gWwf^YmQ8SFBxD16D4MFuDg+ zg{`cW>oD=ejpE$0#Jk93;YHC|geX{I0>|7s9Zxo-jpazS{1ExUD^pU%whSOiv52~9>e-Mvb7%Ax zVzt0mAG)~fqYy^Yn0nn%C>`(RlyV_Ki42rf)+mW=R}-_?f&&d}I6~ok|FLB2k7%KZ z7D3ZuBxb+)1W-qvwHR(V2*b$MW@yjb;zWjWLg#V{dS3`aFPY?As8P}gC21`Aju~42 ziAvq~h*FA6m)gou2;$A8l2fW#>dT@!CPtA=Tcbq~V6V5wL_FRy4; z&Sfd|EA_l%=r&qX@4p|T4ypl*+xXZ#`DM_D>)wo(rHIIgxn8@(qRF%SLU@U7%lzeQU6rG(G8#yLhOMtT1d=rm-_H0~T=~|<0Cv}nhx?yFz)#t3`m)4WY 
zgHdJ0bHLNNB)$k~-7w`vajQ}-c_I^)>r%*c8_A7)-$3rT45|Z&v+GIeMcHgCtO4>z zbKBkTowW>Jl0G6;60WygkFU)s6%AE6CNsR0CnJf^F1e%^CB9Nws0cF!6FOrnxc@~I zkrzp#*S;H}WmNFlPvvIpY|Ay@6k)rOG1=`kkM7;AD3w zE`1`g66GBqqV!Q(RQE@#Lx#U41_p+L-Y$dfDobslm88-S4XY^!|80ys{r<17VoAs= z9AA`3-M)y|T*G70(;e?N7~3R#fidst3GD)%^%aV-F$mAp)Ma=ZPRMhRSgxF=_GT{r zP*DWSOJ$Z&XAu@!Cmy?hYbn%1U#bh0-v>E}Z^bIq+iCh~@2#LNN%nF+W;6K(}CwTE(E;p)=jf>`P`ZfCCJ5=cHMH->ke;|bv*6YpKhN|uL(Ok zj@Oo72R@lR+JMk2yq=>7LR5#fDQhCXSPn+%k+qN0uhVvkhzQDHd>fXT6+NFRyzb6N zA#8s4mP^`i1TqKjCt5ESoOg4C0|Z@1hG7YvVp1Aya*;G2SSfv+hCrl~sLLxhk6_M+ zf$(I&LeJNdjrrquXujv^x{=SXKV@__gs?wqOH*xr9}k zGVWt#;32Nxnp~n5ZWSn~DuI5TaxbYSqRo%j`@aCm>Oth?Vuqk;B5u8Gm8k;Zr`Z=| zS{?kFF!K}DrD9o(DsA$|KRfjf&{nlg>IItj=WVXF=z;Cn?EbHd4aeV=<9>~ zkUp{Z&EO*P2YcoYgOaL1Z0a-gx|CHC7TZ z(XJ(QwBJ>8Ktk$%HBe_uQv0CV2?I*By9|9!qXq626co-?>DGf2-FKcw5gFFeO5B5* zmWoSMpN9|6Np>^TOdNcF|3>g-#!x~$MIle4X(xg0W$Ja_saR)o`A2>TOe9*j#&|Aq z1tB||IsBWw*o@qM?~^BMLE&|r@{haGL!~+NAFR!OU=D7z|4oqj68zMWmn*$3KW;6m3QDA0n@kVRLSC>Q zueR71U?t*EtdWaB742P7;L7yq^&%`Qc??mD`@sNwRUsN7Mldi{8Ya_p6ybOckp#KN zc9*#`umn80BQRv31bmxD)Na4|6%(2p*Tg80g0JR;9OLl4nwE@jL5JhfXNbUJ)btuclapreyh9v3vgD}vHKe&LEyG~0=&sh zkst|q>ek$>ykT_rrtW>~=Ih%NyXB%y-mld@j{T78?Dj2IE06@M8uqT&##(ddL-xBq ztli!4a-AEioCHOWqX`#Jv;Y`RMjHAqmUZ&w1*>Bz-{axaOwxB_SJ#2;FU{JMS$q=D zS`)AC_iyL-9fBUtvmZv9cZ^o*06W9erJX2pwhnwTjjqnt^A-Sx@#yvsx?yfqY1&~3 z2u8}{GY6*#O;OMsvVk`zV*HwbX z_qX48W@(~o+^lOW6<)z3$*mde?ZAR1+UsQ}ID)WU+uVtQ-v5#{PC+Bwz4hG-<_euZ zs?BrE`E#!D*Sw(TDm^+0YV|^l-hL7+MwvZJE1T3B4Dq&?!@P^t#{1n8rR-GZouGqk zQ+R03`^&$RNOgw#tFP`W)qqFcD9~_G{qI9BO4Eg0&T2ANoE_G0Z{^T)-S1|fIm`m9 zT>!&NNy!s$Wid5|o619?mVD922#1-g;5NZ%|$HBhl1& zyC;CattF3*0l=B%35kyIaT|(aJgQ6Z7=6BGOkKO_MUtIcb?(q9wrz*e9tMNE^`42d zM$O~vYD8tE93SS|X)49ALa-G2!E#8~sXa2D((UQA|bE7|UFYcvw6i$xi0XMgaC| zW@Q`8bHgg_nP9k9o$D1Cpe775`na8blE^t&Y;vDodIBV*MD7oN1p#K*6EYqfgWNH< zzSq*kS1L zxzXroCNiC)!SZJg!SGEc7B1nA-P+aQ`_(XI`fBW2O}jULHr25LR{9~T-~+KPlHlFL z&RR`%VaWrP8~o~ap^WPo3gY(LlVdiX*QIbGdnvtff?5|9%R8EmZd*!lWZr?R*E@LB zw&RGxLc~WkhDB+YJUq$HVo>;?_{tR@o-Adn6=IG*M2gn(vOFzOnVNvpWFDsWw7b{J*EDG4uvJW=<_dd-@Zl z=w=u}^-;&Pw!v!Bh?MncxIBI%Q!PXc;~n=wl5Kha5IlK0fzKYTZZp7r=MI$^{)Lz+ zvEq|4Ryb~D^C-$D1>QF6b8tJH8vu*IZho7Oiwc*dr zJ{cBBHhvL++vOf+llgIuei9>-_yv12TWY3YjcS`0tgJZnOJc>r&qx~c;M-*J4vSM; zvix zQyLHOJEU9PEny`h*|6VrW-3y`F~TS$m+F#M`uFUq&~_!`{{%8^+a_T!%7)8eyyr)(EK^8Q zpWDx*Jx{+*vRsc#@NC+qANSPtS6dS4x3b*9)T`lS$WKl;0ZYY!wrur-STdD871b)}0WZ)mfNhCmp22&J-_y=cw9kddW+t;259)A|j?qDdHuvV1 zl^@!azixxk@P`a%zK0w$e^$o?KEG?5f*gmXD{z_P?w8YgB%he!?_|PV{mSgeA&P z`-n=hIRaiq_1Xo^3clzIP-f6N$Y3l|j5N$57y|@fel}yEy8fH*f#b#l`la@IrO6Dt zI&^%lL`fa|ZE$eLfam??+MIf3Y@0d!r>+;fW{w@r@Q)Fn*F_)WbgT*L=5KsE2qadQdBv!$NIiasOfaS)V6SLh(d83FpY|Mg+_3)0k!dYH8L!9^a z=F9lz6ynYs%golzlT!B695Uxxm)GE0B)j*EtLdlgZ@cp3&L1fCn~l>wYM3$|C%9!Z z{JW$)qQ24pw2@8yM#fp-XgskkefP8ajDwtxW+`oPM}{$)U;(A9 zy2J_Lo0*6~-ky8&tb7iOh8zPG&+8Or(os*Z|4F8yo_-*9i`i1I89`8ONnQKy3^>N^ zFY*zYvr_Nv z$D6*Abb##<{-!0YL6Odw(J7$oXM0mB57< z)TZtM`=t|Lq5q%B%{BmGDmyQKu#jS%$`l^2s96if^F(uV`%xg{%>JFxnlx$g%DSQM z(agktlYyT!>qzuHB)Iv~^$0L*X0cvTIR%YNF}bYFX#0m*TkCBxM~QvxVx4gNh>kZz zY*%SBo|q@z4~>a<$The}Xx|d7r9AaLUsF16jS&DJ3qB(WBSvK5^7ERIG+NDy8-MzE zLPU*72}+6&t@Vys%~OXQ=?e&CrxL}AP)z&i+7k#r7PNB<&}JRBc3thEd*JycBA*=F z3g|~26g9#;8^<`dl@+4t#Pi+Rz+}>NU+{>gQwyu+GAwDN1F?RxX_-&;?!~ICQBa&H zq{844gQvGCT-+0GWVzSaKwc&lk+}x&l+@j0kcp^x8Dx{cpmDA3zm+NB{wqX)IroFV z-$nd`v~RY(w}as8^5j6n%>K3#BiEax&W+j2(@lw5*=Z4EDPX;*cy z*h&x42dq7*j}Efy9(TLXJwDF|K1b_~jCig1UN;Z6o;xQm(4ejco4v`kxc+bnB%*WI ztHuCS2;FaBZGz|DK*kI_ETLT&K;Myt#n^oTyk?s3H-LB=p2=VKk}o8)yRFrLp3hg1 
zHt>AC=%O?+5z{||^0FwY6Y1qvH`pw@&pqUu&%a)f9IiJ;qim{G8!Y45g4CZEPLp-jA66{$4y1%B6*3++?mMCh0*T!=1MeNa zjWiZu0?5ayc&~*7z!VjWG8MC1MQFFR$7eGdQj{Qy%u;}(M5@)v=wHnHfs-UUW1i`G zT+;p^Qm5Uz)8g!QIxXQg=S$Y?w9JwvLrlirCKFN`m0DHG2SH|?J+BVovj5_N&e7?pb$EiIZVRQx$E31^WoL(4(LHb6CTcx1xA0n20J zl24A%W}KbX5+ca&Vi$k_tg=v~;^*A>zA^x&$l+osCHnnv1xZmGjbImRck zEtYAZL3IzzAbsKTjrDR!kVcf0zVT@e&-h5_amxH!Wk1V_NXC5R0592>>C#z#RHSY3 z?;$KMhvGZ+Pq%Z!r9)0r%$Lx8%Dm`vkFCF=Y&#XVf?80E_aO(pz{eXXFOdve-jw#} zc${+^kCe2c;?czCp}O8M53)6quGa53pC-~jPpi{(u6rkzT4}W?{SwYs2e~6?s~Ju6oYY$jX(VZ(R+_J|z3R9f+a| zf-I96&@+K$FptUcQBmm6>?d$5yROK0ypnt-I`7&u_6YqdbaD!)oBzxZ!f;XVXl>X{+VP4uQ9D%*Naz? zI;hp-TEcm|Kq|vsDimeO=N%(T@a;;1=$HIE8-jn5q86y6v`yw zNbs!>qc3yduqs#7nA5St%k!uZ>&)LzZyy!`9%L@y7g8jP&u`&0 z8^B2Gfo~5cZkI$sYwTQt&-<&w{zw3ov*MlUB{$v--$^)|=H?tiYVK&H@O&A}U06Uc zOuu^bRq*vl!7+LcWtfuEY&mMIGkNC!zHnd5!k1>9;$ zQI<5GJWqf#J;}+mn$%zmghsFvra7Pe7n%2{dOtRDZmeCcj9^R$7Ty@H;6AV z))_NXT4XA(r%#>BV?s8KP3j!^fbRzpWrHSr~@7 zD);(+=7$I^#?6Nei5#l^8p3+eg6-)lu)6)c1xGfSrgXWxvy;kFZ|wEe+z~w9BB>GM z(8jbJjoiyR!4G(A|%o0g@@Km<`FbJ!2_c81c$jvJt2sIf`s|RN)#q(eO`BV)L4<%g;lH$YAWq;Je8{e;J_#;#64aCEp4rJpV#xF&VDb zhZ@mLTLW&$$cLLg4>}%4n-tFZ6p-g=2S^BHKFzzz79K?-Lw{q5B85M|R8@R;kJU?% zJNMmBa@5~pq$ZFuUNStK^19p53~uT)&AyMoZ{9rpW1Fv1$-P8cn5vdCq?+FWw}v>$ z{o8`ZG`F4mm(}Fc)Vlgyb3Fm2Rqer8$9USIvfmDhjwE#e#R~{=zricpk#aiWY$$VAD8c;y z9u+8M5U*UT`bhi2rh^>6>3*YpCJ@!hc$kUbmR+gX+2Qf0Na)v0{gb6VM?s+>-bP0! zL19BDx+XPDF+Al=lBk;uTly_fFN?|UaXg$_vS>zL3*Q_C?671sWX@TcGfZRtsHj#Z zHENVNFq<21KBH4mCOP4LwZr5%m97yY3UCj*?dMeU9f#}=R>rQ3;dG$%MGSN_liI!a$9NSoQDGfP(^vdM5h0A#`UMQY|sNPZQr zcL;BPdUxfvLbA+{)9W$8F$nyU9~$;U>4(<|ymya0NbgM^jKqGhb)nEP@OL<(2PsZ> zn!m;RIXr^mvQ0|0?)I>o^M@<7pwHlQJ)Z#j>q@<)Q*L{F1FseQmJ!2l-rQJ#k~6&K z0MFHze&_dR)iAkDAgJjs5`RdXr+o4j ztguj98{wHcv2FT`%DKsByYYgY&J~gP;hQi`hFms}WC7T-E1^m~Uo4C{4JIrtV6uez zkwQ0LVvv(vONlKhA%X2Q0+f-ROP|anB(Qfg%VD)(P+NVy0)jy#GVl@3LeyJ@&Gf_{_bU{wo?B!pe11~3T7vcs1K zO0CUSm$q1>|Jn(_$d0h!{v0v zVHH71v(uh5Q<}#GH-Ul{krl^AseVa0KnRM*sv0)o%kbF&qdUG;J~%!QmdBfpc@76*(JQ2q558T^7&|iHgoD7JrW-m_!$s1?L>0c5VhQ~ zk)QG3uONb~-p(g`EXaJC%4g*AjqyRD*g+DKo@yv=GBKC$0;brF?F|i0;)Z_7Kfh+^ zq5FP`vVNz<&NaP1Cwdx^IJassPKTw<^aKwBAc$t6bjg3P;Qv8F{=1^Kg9vtIq~SfT zT;GEKoLLI=-m$m4lT+mV3>A*3KQ=BUE|+3Qp%@B9Pq{evg@j9QE4tziMClYv5K_IE z(h%YylKJYO*fzVq%5Zb1wXi#gIsUtGL!;WDnJ(rtG9t|4bh)Kl^}ZKpWqqUD^`Bq# zt(2VCRd^evoQ=_&o+%nBYwwUUzoU}FEpgMuZG}B*(M6Wb4t^c2%y*A_+ArJ!d6zGa zlTtF_Uv?D@2j#OIj=i5F@!8BSX=`Rm!sxy%>(?Jp)CmBV?9y4ixZIx2ak5WLXn`?q~W@2;D#_vWTC zXT~rnOOemB%a51)={D|Lxv<70luyhO_{?rIyNa*`Y_9htYua!5?@x?`h-UB@HGN-4 z06gci^`t_x)g;mvmFF@CP)ZRB?86bpuN0Y1(pI#guY8P7ODVA{{lk&_z25>+zn$#h zH*9(H=ObCrJiNXfi!d7s3_rnesM4RY6anq~x@B(OxJccv`qA#x$^UIg)cpulAPL(| zvxjQ)XQ+1m?c)EfOS^bWplVGgO_b;eYk58ULKHGwbn>u{bUFW=ptY4cX?-Le3q?E~ zo5E%GTSV0UNR%|xJ=``mg&OF=sQOQI24yj0@NI}gXtc3&6)3j_rrH| zLNKq{kRBmr&$X8FaDeZb%`{!OA(!bnJ~VNaRXO)+QhJl3zY#%UkHM@Wgnx9bp(!&B zV#@EV+6@$lv*v>=ADuKa&}=pUwVwvHuXg#-q3sKV0_xPM;BpNH6v#Ue z?Ge?EkqS85(#_;l@Q6Jfy>AW|5d7PYvBUv8QE$0fLnM>D@H`#wRycZ&Pnxa7E|9KF zJdA@E%J2fFBp?o51bOth(eV{VJbu+^_ogOVsx& z3}>>&SB^RO|J~>OA7&A(vah8$H>wBhSYgKh?0NpZeu{*-#c{^ybtJ?0G3J-u0bo@% zy)($GURY@|t6F?BqQpQ);DAFRE9{9!RW=chfAe*D-|)b|IQ)}DnIbo3h5xMox9A|p z8z=Z_GbqN+o0(gH4N4MBSYz4g{c$_i6C~4ehUznw!xtUL-x}J(B4y;JL!#Nu2gKU} zh?HL^;{K!$FrywRmSXiZBklPLXLIAep7E{d1J%}`m+GHVrraoy~f+n zc(fZ>R;P;lpE9wY50Nf|&l@`*$8DG2t)DSL-k`wWKgKkxw)~ziOc8PQ>g@j!WdX=q z5YS=0uqc|W_*-$a*sR5FCAU{cR8-ip>|O!yZKtz&Rl(AJZ*uy_JAvJ;`_txQ0dlTKVl&ie%8`=*FIoYi|HG>z74!% zg>S)DOeCVzrz~M_MtoeU+ldTC#E*TzqpYG_t8oITv z1zBSDhwMF~xWbEJ#05;m+-*cL7j+1TL_YjvgCAgO!j1x@NsL^SuY 
zH`hth$?U2WiDeO#a43($!7%21Xd+jbi+%Rf&)zIO-U|F(KC<5t!nWn=H|NiCzF)_4 z8}04@+NrcAKPsYel_$aNJeJaazYIGF_|&1`a!IblrZ-x-U3SC1{{@z7ABY~hfPif? z4IrN;ldQ0=fd#~G^wW~rdb#Zgpz$_uv1SuCZ+Zt4I6pjZ`s9h5jC4~SYo;0&^YMMJ z@#*yD^FiaS&3WW9ov>0xxJCLBEZ310Z>j4o2azp2HC?!bLd$jlh`%Fe%-OnR2{gz)r{HENw*g#0}V7%`#48 zEax75T3r+DFrOW`#p6)!VBtK`L?z+?MFb0C7-U`#)_CRHjrNx}%q6;OU3v6`F; z`?PsIFV`Nt+(-M24Z?fv<#5&~s&B85scwi-i#BEDFY`B+wrT|;W5uSY;bMY$t(Ibmp5wuDfWpRhDdldC2F(?%%e^I z{{{O${@1@s-Bk|ZA-x7@k$@**@9Fj=_@d)kH}k%v&w8bz7H8E`KHo-Lk>UGZ7$@Zq z0rUorl-T_X%1!s_^zg0wQH?aGI2A`}&Bh9~W-t(#=dslddxQrj(H}jcr%xyDG~FGd zF3K9&r?ER{wu&Z==5bgEPt>&0-zs4NtWH-0Q2B)NBid~iDm5utSEAaeVbU{Bxvj7* zPsqNGU!~hL7Y_2@dp?Y@T3SkQnCe`1u0P#xI#&L@-XHIc)>+ba*&K&bNRpa~(UyW& zt=J;eFEk=D(}oCfNEiN8kMN^Vqn*}!CN5m^4o7;S87G@cIw*gz@{tB&r2-}MAYx-t1FVo{M2Wg0HMNGg#ji^QbCkuL!tirhq(b%;Uk z@ugpjZ=@1YTAkhm_kG}u4i(y;7L4F6#e9aRmjj7NT=EyuU}XODB}120>qe8ea!qsL z8I$b+Yg{8hI=!X1GiYg+K9kjZr?&${D1~p^8KIs4EiI%{`>!iw+g~}T<_pRWOC6Xp zaBs1HVFex_zrg)Gy=mKUx4g^Xo&;(?kmS?`?7<&+UCcYmu_3~#k}I{A$ChrgdD&-M zK}(srw>L(~79PI4*npv_0CTvXDE9aOQ3I><=6;aK19a>~qUBaO zh4wUTxB3Pg`p@Mu41eXV2|K8F0d?*D)gZy3%_0jhzoZM=0v?&HrW0B~93TjYQ^J!o zoR{{dERXwJ@#@UK<@*81PysE-;c+_ba9f|-thaW)2Vx9H5>PJZ9fV@s#JW{GqLdpA z%f%TrW~k7wtL37%x6!&;`zsg>$6E8_9FLBaI(;m3xCzyt>Q`Jz2tYb4v6d&)n~ic#TH@%1zNfPq%1(aLBO=HjC4}qoA7&uq?J|`#bRSV8VU` zldj~c6m3L1@{CHc2xYZ&E;iCd_CQAD=6wC6wF~r zP2cx)Yo^ChE1Eb$3_tIm7m`HRzm=i|4RIJ3)86|dC>EBq|y@*G_y^|(D3QEd>k z<#qRb0onS$*&vhz2a3l!BY8D_dsI$XsrMJ#$Mf=ie@N+Q~gn? z_*FSDNWBM?Ltn(01Z0@I|C0s4=epA-;(7coq(s9lOXP~L0O%8|*yEja)l#+lc*(u- zbbRvkh^F%xnAyx7WL-+OFJRUk0l<36I{4?_(cNvb8@4!fLQyK?EH20`(J^mnv}n*P z^F{m>y8g4!{0!bCl$hZvW5!IkX^J zt;@|;c-49;rAa`t;@!Pz4WpkfQ zN~MkzPNPcDY*cL!_CU?i@TV;{KsS+t-te0{_${$USZh7pW?`~;y&tvF@Ak^Pb7EMp zzTN$-$CpcbtcwZ=G>3kA+Pp=2|NBs9S+~Vux5m^#YBJoQ^6oO;s39sB$_Qu@05uHNIOuDGs~x=JLe?c?vtkd4P< z2%DvcR?f#5eo_)ImyL75lz}NGTg*TEKU06{M5rc!3Y6wPyXo3wsb?vkS7>99^%1vO zesmqlAo|cc@uLX~N6uw+w;th^cdB%XXl*7OCG74<&_rzuX}r;PZcmQN?=q9v_#Ti9 z?vEC~6AZYxdb$gUXZXZ?^v&1gTGrQfsOolB-`U}9X~U$KrDq1En2*Q)(m*LRgdt`r zP2y=rX8|LeZ29i$M8e_EF93Mh&7H0h8~fM-+5D)KA4c9UuZVoJ zvEOA4xw~YPKH6VXf0zAo|1vVCM0aeh>tTJIjhxHm%P-W!uz{ASVNt=0^D$31zDl2; zjmF0MpMZ^vU{&pCW%oft6ob}l#YO9`-}AwY>RXQoRR+5*=lD*<%(s7ZYS4Lz^_BcS z{@Gqp^VMB!a%YR-Fv1n$K3s0J|C0|ke>S^YdjNC5rR>Tph0Ij|&)ko(CIhM3a7}Y2Se*C>VD$y6ke0oP?Gkmp^b7Sb)*!Sx; zf%mBrHJzs}bKG*}v1#|~)cs+s-=S;nlkXU%SpTQvZD5%G6i6~KYroH1{;=X_zJQ5p ze~Se6X@w`$Twos^0@ZKWLPI1H2OFn%U+u>)pM6oO&#V=7YUTdCzyG=TKEhYZ#&U`w z6%3Y3;BP`gdv2z{*k-an78G31QmD8{dwr6?d+WA{D#Pp|UbHAeA}l4g|sMcYp6lg~+%@ zsmOV2^(zElWqx~3I$8R5MaSy<i63jphyiP7lfg*eGFak-^s50lav=ugOACOGg!I1xgny;gzkNXjdY+SiHOy5Rc&Tag`INojQ@Toc6bBFcj3%w{uk^3i%s-d>pFUXTCem&7sy!?jQOQ_w zXi+h!W;!B&7?Nqj?Jjz%#H!ns(Icv+M~4`PcGwEwFhOzo*BZX*PMHkMp4$+r@8i*^ zSSr{hGLU+El1xu$L2RDeOd?qs+hk?=x%Bn4=p_3Yul>1_gj4G#(+a`tAYvKwN}aaR zElj6&byr%&=syUw?9hn<2kcMdXdJ+oE<$kN)H7=WA5?GZw?qZ;a@t3mpPe4p%~P$bPq>Gv?DxLH;X zqwA`fc#e7Q0WJjej7klQWymSm#AW+4*jdZ-z&;srhCw-0ejx_rex6f`Wzb;Jm%hzC zHeWQ-{0wur%Se6kZ@ta`d`~A4&6l%vOv*P+S`jZEt9-Gg^8A7E?yhl5sagiZ4qW&J zEKv-9)>n+$+BM!mqZzdo+zv6WZB|uNFkcM}D$Vx4J#G7qfAI15uP^prKKIKcMpFz9 zTVyR^;soah5Ar!qu=LH_p&zMXfUWlU94T}9H9=3uQN~%XZ6}fbtUPVzKJs}k`~JRI zt4)4D!(9s4^zTy^vRZl}hfOLn{XTjlgi~aC7eZk^RsQ`Q1Wa7Tbbasrq7AQkZ8Nms zu5t3VV%1z9Y(CYLKk+l43iRx)Z?=%-VD#RUb<=CfznC-eIu3qI(K?O_!(P^Z)@c9z zZLsR7hM)%;!0qpnix%2R%sreeK7ak@BTJd)z{(0H6xsl((I1MBm75(Y)W2G^7vn>XK!!~3 ze?=BQt030nbykLdX-~hOW#Iip@C!FslA{A7KRrr(`~B6dOad{^3D0HM4XC3od~Ckf z+?%nh<;(J7gEQS-hn4`w{(OA6(q#9Vp-1pU*2T?31bkpv)1;NyZ$V#gMVU3U+(f}w zm|$CU+h(8s?Kja-OsIMZ~dw$=*Ncb@1 zvAO?c_SL&qDQioVoj3ZL_Cx&PDP4!;L%L13>3F`d>KxQ8BE 
zqKSTptcY#i!0iWxt20U+uSJ;1sUkxUZpSTDv-(+{)zu$@e!GYchKN}8<1MI@duBbiv( z7hzrDuxd$TcabW>xf0M2O_q=^b6Kjht{SG z^;3h@_JjoNY+UP|UEwsE^+u5Tb=gDlHlMR`IK$YxdgeRmlweUT0#v z(~ed5u>K~!ah57C*`0(*UJqn|qc9V@b3c72c|`8#Sur4|U!8YY`XU4ZVU5Lv7($=9 z%dKt^a_lCQSB9(A4@M_HPpm{<=@<~f*xbxbRp^6JcFe6I3KcMOFZ=W-%3bV6~ zcVGg72s;df>j2ifAXBKuKPSP>x;Yv@~9q+8L1 zk<9amrWM?XYK`K#s2HR}oTSjgLpvlDYcq;KTCZ;^@%i?a*h$YA$B zhDmB%7=g4Em>^t>I5h$#P0JsX^xXE0Uq=~)Ui9;No7tANtG`y9N0pk%YV@HeV9VRw zQ;RfTWN?Mzx|R^v6u@N8Zqn)FI`bqgq4ZWzIitrQ`H)kU_L9)%$0xMocI&y%zhFKq zk-&88ubSG)REK+|{Z-=p9Qhv@U@^cF-ZPBs+XJ;_uRpfK!99_fysKpTTM7$VZDZ#>9NqO4j_x=_VuG zTKcIM{IzM{3adpl(N3^u3{jq?3l}jV<|Ah1LBa*n?|tEuh)!v)@9aAIm&k~Qgvm8# zq&B&0y0V0)3hLQDy)80jGTX{ia`KS09B;F-WDhY$GuCt+Z*p$nCr%$Yue;6QRmWee zOi$rx0w6I%7ihp1l-F^QuPI?A=&1j<0XP#X%)hSx{x`IE%2>a~c4`}z-QsbuT)M$& z*g%deB;Y3QF*<$*U#o#IDAQ;Cc6X`o%Ag39E{SCSI@3{todI#kQ2Y!Xq1i5OzPQH5W{+hEdsYtahDTQ3$5Y@|`JM8p~f}l_9 zyi+#p(e&OoeY9!5c;wP0iEi(*@;At^_f6cIoi}=u*-}0Cr^~I!mS5+;yV5QLjaG-b z#}3#orlsy%W?(b;;5ZECWz|M9f5lv<#pZrh>=?0X0erh?F0Q#sGC@-iupuFm*CFOI zss3L^tTG`C&w3clRT?$fJ}myGSd3lxFy{9Mv?=s=KKn{1@MlB7)7{EnX%hg~nBlu< zoBDJb?s8Fb@pa_6R()Z}I2GQ_*RmiLQofz*Sl#RDSg!8Yq$V0@q9d0Lj*N}`dl?0n z{SsGo$pWFSf_$tsGNKAf%D&Y63rv6f4j~vURf7G~A5tEE86=_}>yimw5{2{gSQ)Mx zYM}^aroCn%aGxN_bf-`kJmk?gndYd1%eZVoz04w`Kfyq~3D_+@ZfDrHUmSfH%N0EN zo4){3hNHa++itNp4Yya?{Vww3b4qZ{%YV*VL!6@N^TWUM z+eDINnXtc8rW#)2A$B1YaDPFmUJ{kZnGsv9uGtmpmSg8$LhU2BaJqDk!F!>gY`v2n zmF>RWIxWM*oY7sx&V!(7^@8UDfiWxHi|G>uYck~4g|V`%FgL(4$d^)rprY1`%^>>K z-B;oE(Q#|Mm=5XDCx^eWg$I(@o%XTl>4lbZwv-dJ)`{DX3vK9Kzxw>NjF)u!=c*PP z@pf4i9O6lD_)I!v%kGg(GIggJ zsP?(lSF>HVP9sOxjmdlKgFNnF%1d}}g5I?zo}UBeUS5^y$QTpV0?7M^UYSj#L-1?p z+;A|-Kqsv^*H$ee4084oD8MDM%BT^UEd9h)2Zx^aE zqm+YQsIe}Wx^!SuP+ncdBjWB8D6Yj$^j=D(H9;{{$F(~cUrrfB_YzsHBo&sk#~sea zJ}l@`e>KmgPM?AKx^jG+t28qMFwFbYjn-S0wU#OjJf&fs7xKm|CSJapsxz0vkYCy3 zaAEeG^Y;S3{z} z32)5KZ_HURNbI5G7>?n{E@|20xK*(_O=$|3xUoKZV*TeQP;ExW>#L%$+I3NSeaBPU z3fXvWR*BCo*HY2+5lQ#vfHBgGkJ6RqWJCHD(m!rkdP*4G335mnNs58yg=DIRs-xKt zOPA|QGL~{Yy2ToLt|*7N0&+1_0YEqD^<4OQc-r^QP1#tAKkCf}$`W+u+T(Rlm1>(gY8PV81I^5#P zGrGhxXvvVO^cQQ9Gz^?4=Ngh>Knbk#y5uEn`f0HG*hZ^&Y7pD6wXRROz|(7b2vgO@ ze$ZK`6~#ag!FH_(DA0{m2Z;GCHUOkZ|G^{{Q9E#^6L>_yZz!dz66P?5VNyC8s!F;x z;p(%+K9~m+O|H2Z`%&#%5U7mSdIGfyALdJWYlbPKqB4b}EZj+_$MCG(He74$EnPkh z*yu6;Z8Ne;%sNKGTwdkc^(MyRaim#^r{A*G~!9XvWJ6sq!w1X$D`hL@>}J& zz5cQIV9BpwE6`A%LUm65@Ru-FBFq*rT!#-_?m_d43R?l}<-2_>cBhkM@Y8X=b>T}& z2Jn>CGC5oC`8y3JI*&~*bV`tWPah?>Z_0MR^EJnW|7O3~9VR%1aPp1*^JiBZ``+&J z^7L`fq?=pK-|_JHO1Ct+!~0wUo=Q`r$0f3!4c@tua=_e|W&YK-2hYJ1Iyjs%VEWU? 
zv0tzerg%G=DODbVOd74 z4YHr^mp#X&xSTN7G!9?DK*{#!`jW<9@;o6gX>k#srcWN~BW58Wt!9g0t9PzoUI{tH z*Yc~HgvOgQUO`$VEV;~L;>)UuriWOA>vF+(vxGR;T&QClxjRHBD_8G=-+1rs(Jx8D~x}=~t5`X~~Xe{OacVHcj?FEOU-d zpVm2hOzMo_mn~@5R(n-}J?M`ssmsRla8rV`HOAIoc+84y+Q=g6Hr(>HRN#E_04Tku z$co>ecKuat9lMe^;mswLFqO)LbbUW<$~T`)TfWsNef@p=l8SMYL_uxk-e598Al{*x zSs=mp+%M~%fuX{!40se7)JRy;pb^m39`~8^>@!Npjs|av@nS6O3Gi~ZK&|ioKtAP))BrI`A50ZO znH{gmX>H@(EkjYo!bprl{G$(=a$ASkiCM#U&LnZ8U}?JJ^AyEh!ReDiyf_AkMIOfzqi&1Y{AM~4EY>b8 zJQp2Fkqr+Kp(}kiI({AzC8DpX%!&D#z_{PgH3c;I5fUctm&7IKgot9RE>c6j4JQ#Y zh!l8{5uve$+$)GOwB9udO8*c&_$t(f-r#`Y^`0(w1rRvS*$U|I*_fWlJAuZB10 zfW}S^zr8m0g3oPl1KmJueHXm6v;&{SI^in!J!^0-8OTUgIv_2#>Lj!LY-&7@RhKmh zW)#k<>pLa=CA0j87G5iEm5HUw%nl(cR*mpAkG$pYs;VLk+2H|Xxhn`wn#Cf>&9+%1 z7HT=F%)U#e#roBJ1iR?i7NN-pJ$q@g7CK9g+CyDNZczcR+&3aphKJlNXcvD%xaZ?5!<;29E}YB9e7V9mabyMOB0{bPv+3$- zUI~YGB0AuIYWU zhU_^NUl~IFrgZk@Q)K-Q59=2?!;F)DK75W&DjeT34k`S!_!b-tH(^l&+REW3XOGH6 zCD2uO8x%HsxOgd-Zs15~sw8Xu2S)}KE;Cg|A$T=rC;ilHo>Et`4`jHqFoiy2jNi`oEU3HbvLgww!h6Y(J2kV0nak&>x_*fLH~L0KNQ z|9hM1jJ%OC0m%e`8(SMSB6cK2xiP1GQfu<1hBq!i11}aaXTx<*xzt zuG)bU<`rFWSbNT82qx-azf!m=dM)Ax1LK5HjTG3FfO6=hUpT9Ivje$Mv)^+-x;HKO z7A`F69+#hsu?Z>ntrnO<-dUdC-NPj0fYp=}gN+huFk?mnO3kJs1qIk_T;V$9?n;f> z$YwrhC0+Ls`x;uz83|j}eRJpZo44SOH>BPQgt=;!yEF6|(%KM}AB?Y4R^ltppLH;x zF_L6UwJS05^pa2Oj4YG~V?QWUS1WFPY)oeq`qL}>RwvKKIm9t@PP={cy+b?0&u=Y= zKT~9JK`xJgHnb7P{?q)Aj#t_bzxD|}5w&How>OnzWS;8!%Z<|XF+3Ij|1BpQ`rlV$ zU2#MoxAETaoKC-5%4H&n-CL|Zg^6a43K4a%j0)o)D+wvFj-@0hRVMxV^a!y0m&O3K zjV5F2g?3Rj`MZ-Rs&{FX;JZ4K1UX&}J&ApIO_qY2li&LL!W{ zHt_3F7x=Uvq@sr_War~XMHg* zQLXg!eG5NHQ(>sP&Voon@IC}Dkhj5z6^(mOWHlkJndeIw-f+qrxK2LJgeew&K{Vve zhRCxmb;)=H0ZLoSM~iYzmL#J7sz8Ub2LQN#1*nyLzrXIgy*!R7ZF8795X`OJ?;Af; zS*0k#g~aU+=0S6bK$W;1bI24uI?98qclt5O=2_a+4cd|a;ikYT@WJOgcbb25)1-R` z1Glr~dmNfHjv%dSDR4p8OL}S2iojm?XM=@MT0a4UCz99}I%HwLV)t_6&xyvrvN06R zbcxj7!X;f8rtu_6X7h}G1tW4Q+9OC7qfqadnBC(hV<&SlUn}Mo*`Xv!`x1Dem4k`M z!;p#OD{)sFct=S9Ap!shto`}jz&D9h?V0Hr*ygmFQ3fs8!rLn_>hrTo@$Yt(QRljn@!I8eF$Ib*=&j{gy7;Wedsqo1?Kj2- z`c(S{K@(QGrR1W__y9H`4{}f_kDUJT2ZL$dYhjm~(Nj@inzPP_Pnd?sluOhh9X6|E zsD;P23G4Nt`RNS+q7$FjP*2jJZP;MqJScaImFT!UQ^A-eGF-4D03{&NpHGsL&<=eo zv|z5sO`%B5Gg8n(|B=OqTq%$2SB?=gsQ8{&=b3h{G5@TLcV(V@h1P7uL1DrVs;p>xpG# z;iYq>Q8X#b9Tao2Mi9BknYaDrJ3J7@;WPpX9to%Mytl?)`FQ691rqeECUv>)avxNn zgfWGG2vFlX=9yuyO49g8_D-+6NUCU~&UaRi+$z~62^c0T&qGNvyJPf{zuWl{MpC=J zZKhXFVUkPH#l`oRi5m^1aM%#iyzX;Vm37YEJ(~WEiLvXpIHwo5PO_muqg)A0$xvVlYpf?WqTd@&$nOzt$Vj903p69w6a!~|KUcyP zk=xP66jGd(B^9Nal8o0%{M({|}TBoz~2N6B^UVbRA)Ir)3K%eMvP=bFk zM6s~7xX`4Ncy!hSk!FTIEY#=(1Gb?4)0QJ!$P=MrR^t zewLz03SzSet5HxK6&TZ1OnG;j!nD10z@L&(0Txy4 zGF1A*D!E+FYF_ccnGQftG z6_Lq>k{-1lOkY6%6w-L~??=yP|KE@Pw;8B_D6Tl(p-ilUJ%(K|A$aFcsw;zS8y7<* z?j-I({1&ai1+hJCTL?>*mraO$7%c?g|vIU>tS)LojIZUxG|zZsbYEMskOJ%o)Ii!nxEqI!j*NMU`&3e z!Bkm%k&a7%#zV{AR4E#FRgr_P=;St&b#beKqy-lEJ(RDDJP~^ai^l7GpN);_E5G4q z8y=9Bt6HLpMl$U+V}-K?6S&yLq|=m(6mS*}?|a>(_SpL7w4?2>Vg8MsVkMTF;d& zU--tL6jkB!P&rT#@PwQznwT*f1ib`#9?hL>}Vk1ezg!Z5B4a z^bZ-8bQ6_<$gK-z4dua;-I-dA6(Av$Qkb`zg!_R{-;T&vn?JJS8@Tw{l?>gh86i$M zT1g#xdoSXoHd^}gCRp)1?L2m`u2kvmDc-5 zAq^VBHLF|%l`O*l{$b0f$(q7HBkZ#AMyQgG7j?1kQn<8eT zK${5pjri?Y3v#mmy8{3_y)K2t)0-=xfS!+E#7lv2t;pA|(Xr<_q8W@SE=T325aP-+ zjfX+Z-_x;F4b;(g5s=WD9uNn@8tH`VlMnhIB&Ez=zT!FIXIj#gV13X_v!esV~k_jJp^Linw_+4>&XOh!yPI z4m5%f`(om4Ectw}_3A0NSA8j@Mv&V^B>Lv^ebBIrK^oz*kPhwUFnor|2yRiIFbsT+ zI3g$IA4y|{kyaKw!l&GhSvb9+U} zqT%Tbd{?HjRxKNK!pD_ZvcmDfX)3-Kw1y1M)5;`^_pZL?`74}0lvLgwL1+m_4YpG+ zIp57&*$S3Wlv@|0XQ>EqeGFGYvPGXM%VR`Ng6ni5n$-~qqE@Ip_x9)UXKa_9>L}7i 
zjwyiZ&lKZ?D5`SgVj`>bKCj~976wO!)7z+AusPDAkrkI=JE~hzQDvoC<SEjoCm&U+E}^#P~Koeo&&25^+B0-g)}NTABsZvT&E~Fe1!_C<3i-Ct5)cR z1DC?fRbx%Y#Hh(|`N4vkkM~tJ5(76ARio8TON{_gai+oqfqNm6EO@3Q{IvwB1c7m( zt%(L3R2ImG?oXmhF_fO4NEMCy*i3(J)QG`U33cf3CYF3)hV_zak4$ZE_h)=ZA3Str zo#ZL2BOJt>)&nv}=l$|V>@lzH?yt>Z>1w+^%?J`7&oJKP|Xcs!xnjT&(2WDXV5Rxi?dGeNaB)^8BO-Dx`Y=95kgrk&KAavvggYDQ6K^P0uWxE-bq>vem+7)lPZ z#n*uDAfWZA;b)H(V=E`3w{Kfq_@%2iN1Jif*m2xf*vlyAgi~Hrj*WO^wz{-BK7#8a zK^I+MjZtZ9gw;>D_mIw&FF^qpAKQdU%zF1iXyFzcV(1%aNA5E7C1^EVwyVPBUP>pL zzi9I!?*(OKau?EenPz;GIh5U;N>6+OPC$5u4kV*Q@cCY>zJ(JrgZrL_Gli`fhyMwCfK z>mh_9&!1xi3zv2YZ;A(>95-J-o7tqBDOQJ4fpDHSsc&aN>rz{%yVhW{Eho&$b(xgc zp^~0!aQvlxEUnX7UX*QGU{J-zsqXI&>bgK)+hwJ z(UKG^peM;(H>-J!xir?YHx8W^YTMGzzUG|tg*VjQCZSN){6w(w2h&L8$z$6JoFHtW z7wDS9aP>)Sd|dnuzlGYDusnCVEe2bu55;@2)?s_pK2>fKzVl!T+&Z=?Y~$k59tC zhp5=PG@>Q~ZE!ADNn=*&Rapc*>+$BCztd>@OfXQH@Ve%a&3%UA;!r^!WmGAa6Kw{kr#^yzGT!L?uk%nX|5myqQf)rTT39X(X3Bx36!mGU%NzLqYfCu3c5S#L7ep z=PpqtlH!QKZgMsD`JAkgJEyQ7$_0i9W*sYKp!)~;gp``Htj`<^x6dYuc&qi%B;KND zuLst?Zl7T)HztTq$Z?=88<9Y$=+tsm?9@%LhH4JscQJ#p@uJBzXEg=Su1Fos+WEJ) ztkGjNZeEMGTj6Z3v%O$Q>EE5v=M88t92MNg|AztFY30xV(*TzLrvcY#8k*YFrHCzj zE?*Fk#3$~Us3x}QW@9F^2SHtD-~8WT0Q#88MNf-)ZjAn z|IE{A=jk^l^Pc&-X+bp1w9DKN_j>kwCXSxNR9)E6pzaExK*8Fm5!XKlPoxLI+}pR} zt;Tmo5(cpaX|>D%MD5`6?IdU|{TOm-3PYHB7zs6rU`b`^`}G6f`EFnP;h`}+%Hnfx zU*8c>{}uzVSXmu=Eci`;KtuX#%R&FUld&)XfB@;Rz=Jjp>RR%!xN|IyD#^Fby0QM` z0?G)`-^*=c4NkBDO-O$cy!I-_&kLJhe#MVZOr@^%fwX_>?F=G_R@C(aI+N`j*dX=; z*Hekqel?W3s30jgIUWX}1IAH|!v?mu-Ic}i+~UJLoJx&k%4Eq1%4yV%`m*Vn0ey{FkI zRD)T?8Ev%Dg|kj=;M}Tmue>U>P-Dt$b#>|1!&R?eaA|EURd?ka_O>8(HM0-0GC3AcjEWGVkl&bypv|knT6XD z+EB3vQUD5Lz&=YaUMY$#OQ>BhRY%OZvk+5yGVK8jSV*Zpk{U2U% z;!IEekM*N$+$!SaQ~hvzFK#3x^QuncW{rKmMd&uI@W(wKO!fj>!<6Ile8Bv;vl;g7 z+SFAVy0cC^Xu00%@IbOV&&EuZ75I#yE@k zbQ%=%*hs5%{lKq_)nPrG#KXYvqvwPrn5wr~QoogeV5!}#f}d0HSfFFt!=q6xn4!$DYSHkpS_RT*r;}5 zmfL5FT_6L5JJPr=Q~oL)%v@nHd#V-YnBf+6hZ9Q58}j3_w?p4bM^E7Cb;lb3EGPg_ z{{6m_^HDoR*5gRS{nkZz=bfYW)nfTG`dnA_jIb>olbL;f&tJ}cul&zkKLMP+xbN3d z^)Rtx-BRj|F|wi56_<+=vUDWvEO_YEN{i(XFuR+}=WGW3ib_D# z|H>BU{zQwICBxqzf6AdAvoukUE8wvaXb%Ad%7ku`5}(e^sSbYCThs3N5=fr39DKI- ze>ea}U9;V1_2ekswc0a1xb_*_Au_q5=9ILuS>Jf#c4UK_T2i>e}7d^MOFrGawBymc?i z^ga$j?jve-vI2n>(Aw3`pOPQjAZ*T|6;7Q=C1MSWY*s9LtvZ?!Mhn@%5+*s=+MtNB zvTW1x=qPymt?U)l{fi5BV&NB(@}e1`I)bG9Tp(6;lE(*!ON4hi2faSp+AzQAKZI(B zheULc7rQH`u`6#;Q(F4iHS3_*M5N(i;|P^JKNm6?8JsX!$zv&NFuHM3hf`A)lOJR? zOQMp8N>=Kp^O-WTtE?NClitc_^3F=Rqu1T!yKbjNRD=v}Y*yMbYw)+n1=kvhFHhKL z&nj7>vEMdC-O}oMvh{HZ`$WYn>Z<@Fsb$|0*<&1*)$T|rL8QxSqFxYkt=W{fe@|xl znCui6vLEoJV5EHqSd(&8-pHr^ zcQBx@Sc6uCF_RaZU|erC0*E)=jTgoqix#(748=bL7jCRM*1;Q#POFU=n9`=9lFR!GOoE z=VAcj?RYmkivpka+$Cui$RBeh2YV&$m9H5Q<Ta#acOI7On>Jm zqV402SsHcHQkEWDJN%Wr1cAFJG84EDM4s?Q{5wx~N#U~>E>@+S36Z!5FED%D5!Wh? 
z)29eg(byqG0)3m(m5yZ!FWpm1W8dqBSI6ob(C}lx$TR1W)gq`q`VNH-q1?Ac{))aO!B% znu`UlUNf`&X52>dcj^MHR7}#szMk5FDh|7xVrSh1zdl|HnJoz&1wdZmxpuN5a z0BJuyT+Q**MmjFO8CEo7Mi8xE4s%H8w|of08MYis;jq)_9FtaoQOgg+{K-n)e5v%R zH!gB)Sh7H#u{I*~68}%tms!gg%yor3u zi48RznSy8W8?F|~FbAgHG=MaJF`fI^UDs$kRZ?y6U9t-JxOSz#ltSh4T@j)b{;>9q zJs!M7ok$Hl-Vl}j0Yuus2vfp)rHBE_a!}JtW<_w~i5}6pjHDBH9!x^QeI54(;7yRB zTdA;?);Gemwf}&S2L|hIK)JJh{gr+euwL2NHS@GI${aUt z^PolLmHA6E_wkK_HR+8FBDF?D@yR(Y2KB(rF0RCo;R@P^1>5Ej1#VzGXF7|*Vx&u2 z8=u0e=lVv-naK7kqh9+h3f53Oh5rDW9D`P;Y1AG?8gHeDFsh)SwxIzT*U$PIi+uR8XbGGQ4vl5R)NjEk{pWepS(jlKb3LrvLs(F%Y-;@ zNRfmnPBs@y?~D70mVkiA z!Lpet2&!=esA3!6EY5KPjGNVtAm(6)rqOlg@8wDK#!lCym|rP@O!#Qpa*!~1%@Zn; zR?xM`6Dbpq{l(!|i~kfOD4yK2oI~A9VOzWD^?C)Dy@qghRaJEIOdG0XU7bU_)qIU< zz*F(7?2EJ5=cFPHfWD^m88^>?F5+v>Kt`WpGqTwWq~e-L05bD>fA#aPmLqP5G}d6~ z$T*^z57(ZVegXxppg@~0MZ+vel~3FV0aCFBtQ>-mA}@~3tI+>C%z{CbY|-ia@1rj( zqXjN6pv?Tq$ zUEy6)Ct6uKBy>L71*)CxjvLv-_el+w#C1~kL(%PfYkRLYwdN>Ie58`X)-R^mJ-^Lw zl8KLH^ZE;mvzqB{(xx6<*aO|_sID~_Y}931_BhbQNM5ZUX+{V)P~}bBYboQ(kAIFgVB$W*&9yJ- zQ&U}RnNDU|U6fDgR&YR1(4}<#t^K91yfa1(HWPk-mWDQtE*UjtMva?Kk42Anmx`5u z;{aTO{M$S=ep(~WuobDW3b{YWcPC${zAK0{c%Be8a9x$%k1Dih%&NTKn`Su}%-~O1 zMX*m`ou>}5;I<0Kr2s~chT|;}r5bH7UV;q1r6EJHl|ke>YD>kG$U=&XQLQ2#^Ph{l zlk-?3hzE5cj&MX&1(rn}J9QdODQy&<`C@Z9=LuirJn#!$;h|`eql65szG9wGLlIO` zcA5%Lg-?UJMZ9u7N{`v1?f~$kO%5Qiy|$pmd^66%H8UvauEN_G7&1J^$~WsAq-)v~ zT?e1jCwfUJpg!+$RD`Rn=31YKBhmz)d{tdtP7_3qpWZwzx%}{kyDfIXMo!+Rg%<5Idp4Q2h$mF>F08ha zM@y(Ylaz}S7QPSo-Hy}>->@NzD@$Nbc!MK-(&kySH`3f|vVrd@>Kvtc$h5M)Vwiw=4i@LTP&Wp+&hv&a#?&3= z(*jgkB1zn0sfPJP3MQ=$@bMGG6g`d`6Z@S|%JFz6V>%ZM z6PU|?6`5e0nHt~=yVp{MNl8(nI&lgdLOFtlO}a_4$)!*gAXzJ4O8orgt9W!bnhB%D zm7hfiy=j@q>xfWcOcod_iS|k5j1~JNZ`@wZE!TLqB90`)aS8mOW!39shEf39Zf>1s zrj?vOQn=g={%G6vxK>MaqLNo+QWJOgI=z5lvSx%*e6Q2k2b2SnY}A5lb91xX-$!=< z6_{K&qXigg1P0O--uRoEo-en_#ZpAQwORhAQ~j#_9@;~goaC_F?0nohX3(VE`q1=q zePcryZ!F;5H>(i-5LaQ!x7a*aXj{&ecVm-MHOl20mRV*XeP%I{{#+&!!gY zzqucHthD{f8R*wN9TD^r($8)8vmQ?C+_Pz2Y9P~f{`&FH``1Q}Q@t_idU|CGA)(>! 
zN0XKSpBo6G?9{3$e_(z3q;~V@^C?%S zHUgrTLPPwF>i!l6{T{dyy6c zfEQro!W!*ZYE3@_R-cVhFE;tza{I&GhHWL%?H~Ku1{R+ipU&0o&R{e^wRJ8>?3?Eiv=tap zepD4&67%7cZ2_>E-nju2RR(h`iHtj2JFg6z9F7+Gg^mFB29+H5^NUl*E%)&L z`_|VXa&mIOwsy0nT=V1AK1#fN2sq{?hdvOX=L76p&%<#9Z1hfDHDG$Gb^}Lnna^W* z*C&}cVjD+}j6=WG3m^eBV~nNde88K{v)>sdU{nQ(n|Zlu5t{BYabC7G9(*@!bNS%H zODR^WoX+EQns+)_Xv_7n`D{C&KF_l{XM;((e_%kb$@iv(-(D@^Z>GRm|172I zXP(F4^uZ4@zWFxK6X?yw8F?^wp{$&=v~RtMq=v>MP;)dh-|)HG7q+@g3kpW#FEgg1 z`(dhx_5B9*`j@DuDM*LiI6g6%Xs-39`-~vYXCn+}h5cQGj3WV~%Kl6tS`lI}@isy% z%Y3u#ROiFp8K?b1flPHy+o|t>xZj-#D=Q8EDBxhH&L23L0x^`lEuLdJWPUhweF-3uV&bl0B2D z)ej|nTi)T~zd!bQ85pU=<5aBpkDmn3^*--EEVd`pZJ}_f;yk7Ap2sTXKH{UK*<7)2 zIuBS)Y$pWuje2urVUnnz9_)?FsKU&U{BbQx15h>h5F^G};-PahA}F`;M50@eO*fmv zATu7~bE8iD57VoY7^Kmca5<>o9*Ua_L%<0S#XI;CC+SOIhv(~zU z+Ym`EP|~4(ZCy*1LKs;2>2v+gEUqxLJ4(9 z7@nvhIw}9rY|(gKe*X2eRdLMqnR_qdLnM#I=YX#5>(qM7HNOrR?dYCua!iU3KrZHr zeyew0Iw8f|D&<&mVy-1O5cxsY;WLdcFyKThQPA64?B?_Euy~oAX$tDX60HBlqfbcY*-6YM=x{T}>rgz2Y$U@)GH|E%M z&Ki!#LO~uCkc=gV-1ot{(3VRl1qFrO!i}q99V0;?KTLMkzBe2n9tjNyj$tHtfYtr# z1=^Qxf32VA3h#=KzbYu$__S6JOlTu6!>9X)abM87C@kko((kxNr<)^No%*S25f${2 zPs;s=#Eusig4#zLIa1F1V}nCO%uGyr6%MdfQiHYgx`F*Ph53cAK6^S>AF_1(nFk-P zYGB{)D@_`39o)qpw<)I8vpedIy1mkK5)(4?>OTesU+ohb2J-C87CYV##IfnrgM_cU zDP9ib)PrO;j-jn}Ub9C(a=Whf$G34zMs0>RyFdSg+rm{A#a8oN!?xQ7G85V!`bZf` zwkBvRHgio{fHN4dt6`5(`kqYoZ*6U#`>UiE@mh>oG^kpP?`@TV(25!=1X)GN2;HxK zv1R$adIloi^hZ%>!XF1uPXo9V&x4KtU>$CZN6*kOxo~6%`hK4;&nQkY`Evs390QYL z4y~*hKmJ>)COv!4DCEU^vonaxIO%3NkFhJS?0I)JKH0OJFKZ@LwI$gLB!aD_2;tUV zZsq1qRdJ8ZPR`6Ir3v_n`F$wAzqXqGY}!-M`4XF0gI<5|Fom;X7(vL!%*=W{Vk?7i zu2XL-MSRy83Q`TQyz#*nv&zZ8X5o*p17G+;#UDD;Q5FPur zsor79t)(2w?&O3=dhSdq!iIP{zt(Jlc{<-L}#eQiagpLBsFjP{yXFU^IB z*w5_I$LK4d!~iMM3frfu*u+B+H!^K_H%x^UB08n);n7m4R@Hz!iK`-97!S_qQ206D z5|3L6sX>XPJv8iv!X(mz|2AdOhTOB2(2Yb98C*cp3BVAla~cJdSlW|Ns+XxcPSR$$ zFf&O892tgfA3i)#CK4UT6TWXqC#Ncx1hRT6int6@kiyTkL|XkTFNA|<0?Z<}lHb+c z8(P6xT;c0V(_zz08e5+4Z3JaHm5GSiJDuJKQN>*2fY=Yw5)35>r)%3Y;xI zC{PE%{R>9}7S_68fhhp1{`tocM&+w$jqFU*r`dvABonXp{`qG-94+#bk$ig%+sIz| z#w*x*Oy%?iGadFTjky7d6-C|jomdfL*#242|25_%xdhRry!htkMJi-)i(Xn;`Y{p{; zVsCSpbZ$bw+loWqHOnkF7A3qK!Zn9MwV$!JI4ziXVr_Z@x$Oq%%rJSR60x4yzI(2Bn$ST}V*00a~6h3(A4-W_i-(P#L z#>Hd4ni>!PQu~+wqt~BNq+GYn_;`-ZRZ=PN9{vGVs2ct*NQ$>FIGa>2o=nLiqNNKacz7=DAxA1Nw5i>}b+&j87$)%wh1_ zy#Bfec=S~!CvoXN9jrDBFso_(t2bJ0ijDn@Hyo}g7R)>Z{Hpf>NjuZgry?d>#&&JH zzsmKRjb|B3e;?2PeX$wSwL$lN{B3b?m>#rGEW3!!P_+H4qlOF&wwf;|XV)+_HI1=b zwVo>u>!&g9<1}S85c52utvjy=X2@bZGi8LM{n=?Qlw#x8!Bis;N}}yjoeZqw#TFev z&FVFDXuAHc?Pk9BLncD<+k9W&GdO|XA_)gb(`My5^>YB9a11kLGtJgt}KvYf~U z0@*iV5ti%E0Hnr4rxeoLV{*GhhtJ(=D0*Xqm_r!X<`p4p#|KBUJW-?P`+OJ%^HzlSV$N&3ray_Ht+O?eT?#RLph*&1q{_2hDkL3J zi8C%`*v*9E+cn^T{7Z@w%`-pi{Cr{BsGT2F74Wk3#iy;#epg;-$?q@uL{>!rxM&?# z80kR|1%81}Rhd={Ld%K+Q-^Ii{>{kwDBhU0ZYH~fRS)HIz4(xKTvW^K1MxTvG|&w3 zj!IwX59|=yGt#Fj<@!)%L$mTWOrQ+O9@qd1rfbY^dT)HJnZ zJTRE)J*tAgsL*57E7f(~h;8wZLNO@TYIOu`lx2l*MDmK(-E`hw zU%t(k;Np-R+^s(YDJMu6hDrvZkPP3;o&=Nr=)XGcpi=t@3PM7_P%Iu=WSYD2(~0(* z)GBUmokf;Uop%gq!oSvaLHqeX`YH{N^Z1xvF=`a1RcVyrMg!A*L) zP|4u?Mej8avET*Su6(4W>+bFC+~2Qa&jU2`27Y30yflX?TqbLtfG&b%T`>I#k6(~O zJHE!pJDFXtB8r6ic=k_BcSa>4(T4k)58*E4WTVaZNFBon{)-B$LT{A@-tgz>lckdB z7{BN3zPLE2aY2+|Mmr1&{*yS0mtw^t2zpnSa#2x^gtpdaPT#E_i%ic#Gh^yKr1mq1 zBw*&EatDD6$n^BIM+p=2>PGu!2x;c`O9m~eaD%vc<=@;_Kzj7ctR~mm3lb)91oH>} z0MtE?n?Q*@; zF7`ZjL>aI%*6IN_y_H&2S8$?+Rmy#FXK~_&I$DlhM$@$2ODB+) zNk)<7za&=L)*XV`Ppw^6BV*-#Kam3jynK?E?@+@)iHJ~%KACP7P{=|j<@>=2mf_#C 
zeV8gJmf6y1e=2VVd2Do=S8_Fo7wNmNPUJ|feZNvF8`5jDAZr|lS(MCj2~E&Pk18wu^2=CcKt$-$Y)WrC8PaJ&QFxJC+Q=CtAZu4e(3cq$o%DR7jyCq~9d31{;-;f@l_+E(c7U#Gi>aC^V``E<@qy_uR`4j(JKtY&oA)Iif*P0AYFvU)>`xK_WGg{4al)i5bKjhSK6iC$-I8x&&47(1}A)sU`1p)^C2m@}I!pMl!a1RWl z?k2P5MX9F8@}XZFQOr0JZO)yT<+#}xRNpyFUoCDe3h3b~+n3=JhshWn zNbbI0#-vH(ZKQr4np(rx`^13uy>POhgn)Jqs)-pfR$*$t!$*^sE1IT{t-oBoDv06s z@Zf~!E^QgL(9i%*U;vmF4G`P$OiP;?86l^Ejn0LpG)=(cHRFLoRVTYe6{rRIN#37% zq`%tKGf{8feq5wqZ}cLpelXBo{ZZ*jxj5>543VhEx&P}+C^{)b#|0aQ1Q-4br-Gv5 zXGi~l(E+W*(30b2D1{q^Nz+f_WJHG3e@m^T5yXRmSx*>b9i`y1ds5faTA(n2)~rhi ze;-0>zv6zr@oO{?!9r@xE)OrBipNnuW=!`rzPe~*?SC+4$<2Ev9x3wOy$>Tlmri7XBiC^o=MCi>xMGchcD~N5A;9R0F z+evH4^V(GVO=X>xl{p9oQM07;E^GvjLEmLRL8;Yl4OLbW#R6idQ36}5+A>mrCOuhp zJtV=P6T)Z%oaL)UpCiaW)NFRPmz^`j&`q4Nix0=$ic=pZEJaiG2C2``V1Z1Me_S^o zCU%*r0_D~wB7{Qt{C+Tw(Zq3lRGccl@ghe;Q|0Q;?4X$&R9ik%GX;avHgPx#lF_-f zN%byHn}W&0h%OyN&&C-#IbnsLqR5QUx1>8Hp(?4GU4t*ip7yJyOsS|!`q1zWG{Jn; zT?wHnzzV4}Qn4f^3>6{%rNt18f{#kWY{$%5R5<;NmC5H(zz6r7utPh0;I3dalgGoz zX?%q;4Lh|>8^(F?C4NZDibu6nlJ$vhA?lPK3SqDI9+kXo>2iUw%G0%TIYJ}A#H9K; z@+dEa%CZU6E7H89LLh%SbXA!+=k`<)<8k8Q3jQ8AF1lpR(;4xdv|A(YQ1QEevFoV8 z%xX&+Yi9Bc&-|?x`{uJgzH&A7f}MMk?0i*$Kj<|5ROi_W#A03B@~(E!%u2B9yg=;it1RMLWJkV zK1J6ljtZG=1^p0bl=(zIT)P$;h2R%NK8zu#n(ra5QoJ`Ia6@4Tu-Nv zS}Ob&Eq{u|UPiDY!*3(|qi#e{kx#Rmve*sY)TU5mF`}`uySXeP^EXA{5>e$*8Tzif z!D&gy(A&2N@;#Ez{ExP}%O~BHi*1I=VN&c-TpWi1v`F%1H?gZ>gJ>dQmxY12)~36k zQ89VB4yQ}CDqek6a%FivGp`ZV8u9wvlSzUqOOg6*LBtTJOxtH;5M zJg9`YfG({qSv(v~3SfKvbf?#SiucUl#jipFh$(fgW z`_*tcqz%oYAcXnbYO=6;+_lq3hV1JDqJnuroL4NBdslvw2E`)vE-G1uyRT?di2k{Y zIJquDQK?4%YPsUksZ~-((7NpxetW;2h7i<^i$c+SrMg4VB2Ll-JwwlxLw5~NrW;SNPWe!JEF zQF?gM>M=Gf%Eh>7K037_dYd@oGXQT3otoktaFk= zOC*F4GSt_|!f0<4!*=V)VrUIqgu_LuP?rn7%{RF;aectKpj3>7@i&WKxcDpI%d}AG zs7cnrR&tVzD8pX%Es&o~y4r2@E15-9LamA(GgFhKt)mK8=3!XtW)z7)AJu%g|G6xB znubOpBL_%5MGd(xw+j>KB8a$W&Zwz4=44l3fzoj@o%xGoqZD`LgQuMf(uYfwOpHG> zCL-uAFvb^>fji-a9VJmS&ws)PJ6vrvhRV|-R754TCDB##Nu4JS#D)*w7~=(3(qio| zvp$f!FzwXW(}}o(y2qRAl4J738HSC#OXyT~A=bURtHd2d(u*k#31Agi|vS zJ63Z2{aU;Al~z%e9@e)EP38ty`N|YcZ^F{FRQKk0#B6nAqi*z$z|U#d{-~hp?F+Z= z$`dEJFp4ui61dW-MhoZ4@r>svE;S`ld}YRbJSxGGA`SV1D3QeJs!vy>Lty++?Dh|s zQ+Xzd`sN9Uit4-vHrzv+I?S6|c?sTZh*t0tlQzgIP3f|uoWOq^j0mC#^CRqJSndi& zRnh7{(%VDJ-Lw#zYK2>AhP2~H(&C(aD~pRwzb1&^A#j*zPOn1ZmSBJw^x^_hMJQ{6 zl@3QcQzntBDSyh5=?GwdvHy}bp~?>)osv=irF)>-#l!^1rc#C#g2Ue2vXr>WGJ@Vw z@Kr!P$Vk$Fp~olO4?jK>5tbJa<>^fBy)yF%ji*VPSvi#s;GT4$7*wPlQH}PQI%M6{rH~c_em2aK@aK>R-uO&|dKP-Kci5#0%n5oa5jF6K0}>+v)x`EZ#Y`7r`B4%Rq9wc z+9w*o+jX`}ZU++}*FW}q<3+k=gzOs?P74EU$%iV?(Bn&F{&T~DcOe1Rls(53hro^S zvr})Eva!RJ8U}1yu*&BI6$BX>XYnVrqGm%C+M}Fe_}r=xB;#C{pI+3rjf$L$3;8iJ zKeAn$R3xVU>fEir?9~g_lZQ@~>C_V(OJPT>E)II%Z5hSg0ho|e9&^K4<`87@+oZtQ zfm3lB5`k!&nStmW{Y1G+M78LbcyZpuIaF`vRb$63ItbTxa=D8ck}Qg7`FrNS9+Zc4 z=ez8t9|MqMN`nI8c?~mhXIV;M`dqLxdw!zD_eK20>{0Z0u6NQDf#WfH4P<+YleK#2D#OQtwMLdz{hvLND}cB)TmvT&+>3sr;Kn>6sqp$yc?P!Fbs<$cVs{l3_u*h^&O{_70Bk=%dpb;}CVQl(cf+6K;k4;>@2w&qOkYeG-#JEjp zoBEY**GG3@DpGkUe+SA+ya%2IZ})4$auJ2w$Bna3+JvZ7Q2Y&jqZb~oPN%t z_+DXt`W27M{-qQ1+(puczkmDL0qgYW%b#Rcy&*&Hp9B<1DB|56+E8B~VOuh!;QhTl zCqiOI8e?$YR-)2ihx-lqlZReyE*`sfa=B&+HaCLpTFcF@Bo@s{ zz2H#sDZ>Z_1vEGmlnSzVo~^L-h5Sf*g4hq2=iNXwE9a|zM)i5_(%6c#r$u%3E@^0rqo})4rNt}*L5WcUACJS@ndsJ>BuDtRM@IKj zzwKPw##Oh1(P`#WkI}s}{n+>;J!3 z^I_b`?tpiCQj5-4_I#a)dni$Cp)uY6Qj3pn8P9)5 zjfW@}i_9k9cd*GlwA%tppO^pl{MeJ26|B2Q=VbobsnJHYL76!0Hv$r4M{a_;HjJ